Merge branch 'master' into xsn/server_model_management_v1_2

commit 919d3f8cbf
@@ -26,7 +26,6 @@
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>

@@ -60,6 +59,14 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}

common_time_meas::~common_time_meas() {
    if (t_start_us >= 0) {
        t_acc += ggml_time_us() - t_start_us;
    }
}

//
// CPU utils
//
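For context, common_time_meas is a small RAII accumulator: construction records a start timestamp (or -1 when disabled), and the destructor adds the elapsed microseconds to the referenced counter. The sketch below is a standalone analogue that uses std::chrono instead of ggml_time_us() so it compiles outside the llama.cpp tree; the names mirror the ones above, but the harness itself is hypothetical.

```cpp
// Standalone sketch of the accumulate-on-destruction pattern used by common_time_meas.
// std::chrono stands in for ggml_time_us(); nothing here comes from llama.cpp itself.
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <thread>

struct time_meas {
    time_meas(int64_t & t_acc, bool disable = false)
        : t_start_us(disable ? -1 : now_us()), t_acc(t_acc) {}

    ~time_meas() {
        if (t_start_us >= 0) {
            t_acc += now_us() - t_start_us; // add this scope's elapsed time to the caller's counter
        }
    }

    static int64_t now_us() {
        using namespace std::chrono;
        return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
    }

    const int64_t t_start_us;
    int64_t & t_acc;
};

int main() {
    int64_t t_total_us = 0;

    for (int i = 0; i < 3; ++i) {
        time_meas tm(t_total_us);                                   // starts the clock
        std::this_thread::sleep_for(std::chrono::milliseconds(5));  // stand-in for real work
    }                                                               // each destructor adds its slice

    std::printf("accumulated: %lld us\n", (long long) t_total_us);
    return 0;
}
```

Passing disable = true mirrors params.no_perf in the sampler changes below: the start time is forced to -1 and the destructor becomes a no-op.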
@@ -2,17 +2,15 @@

#pragma once

#include "ggml-opt.h"
#include "llama-cpp.h"

#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>
#include <map>
#include <sstream>
#include <cmath>

#include "ggml-opt.h"
#include "llama-cpp.h"

#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'

@@ -30,6 +28,15 @@

#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

struct common_time_meas {
    common_time_meas(int64_t & t_acc, bool disable = false);
    ~common_time_meas();

    const int64_t t_start_us;

    int64_t & t_acc;
};

struct common_adapter_lora_info {
    std::string path;
    float scale;
@@ -3,9 +3,10 @@
#include "common.h"
#include "log.h"

#include <cmath>
#include <unordered_map>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <unordered_map>

// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h

@@ -112,6 +113,13 @@ struct common_sampler {
    llama_token_data_array cur_p;

    void reset() {
        prev.clear();

        llama_sampler_reset(grmr);
        llama_sampler_reset(chain);
    }

    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);

@@ -128,6 +136,12 @@ struct common_sampler {
        cur_p = { cur.data(), cur.size(), -1, false };
    }

    common_time_meas tm() {
        return common_time_meas(t_total_us, params.no_perf);
    }

    mutable int64_t t_total_us = 0;
};

std::string common_params_sampling::print() const {

@@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) {
}

void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
    const auto tm = gsmpl->tm();

    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }

@@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool
}

void common_sampler_reset(struct common_sampler * gsmpl) {
    llama_sampler_reset(gsmpl->grmr);

    llama_sampler_reset(gsmpl->chain);
    gsmpl->reset();
}

struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {

@@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
    // TODO: measure grammar performance

    const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;

    llama_perf_sampler_data data_smpl;
    llama_perf_context_data data_ctx;

    memset(&data_smpl, 0, sizeof(data_smpl));
    memset(&data_ctx, 0, sizeof(data_ctx));

    if (gsmpl) {
        llama_perf_sampler_print(gsmpl->chain);
        auto & data = data_smpl;

        data = llama_perf_sampler(gsmpl->chain);

        // note: the sampling time includes the samplers time + extra time spent in common/sampling
        LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
        LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
    }

    if (ctx) {
        llama_perf_context_print(ctx);
        auto & data = data_ctx;

        data = llama_perf_context(ctx);

        const double t_end_ms = 1e-3 * ggml_time_us();

        const double t_total_ms = t_end_ms - data.t_start_ms;
        const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
        const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;

        LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
        LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
        LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
        LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
        LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
        LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);

        llama_memory_breakdown_print(ctx);
    }
}

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    llama_synchronize(ctx);

    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
    const auto tm = gsmpl->tm();

    gsmpl->set_logits(ctx, idx);

    auto & grmr = gsmpl->grmr;

@@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// helpers

llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
    const auto tm = gsmpl->tm();

    auto * res = &gsmpl->cur_p;

    if (do_sort && !res->sorted) {
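The perf printout above derives an "unaccounted" bucket by subtracting the measured sampling, prompt-eval and eval times from the wall-clock total. A minimal sketch of that bookkeeping, with made-up example numbers standing in for the llama_perf_* data and the sampler's accumulated t_total_us:

```cpp
// Hedged sketch of the unaccounted-time breakdown printed by common_perf_print.
// The values are illustrative; in the real code they come from llama_perf_context()
// and from the microseconds accumulated via common_time_meas.
#include <cstdio>

int main() {
    const double t_total_ms    = 1250.0; // wall clock since context start (example)
    const double t_sampling_ms =   40.0; // accumulated in common/sampling
    const double t_p_eval_ms   =  300.0; // prompt eval
    const double t_eval_ms     =  850.0; // token-by-token eval

    const double t_unacc_ms = t_total_ms - (t_sampling_ms + t_p_eval_ms + t_eval_ms);
    const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;

    std::printf("unaccounted time = %10.2f ms / %5.1f %%\n", t_unacc_ms, t_unacc_pc);
    return 0;
}
```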
@ -1673,11 +1673,9 @@ class GPTNeoXModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.GPTNEOX
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(
|
||||
int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
|
||||
|
|
@ -1735,7 +1733,7 @@ class BloomModel(TextModel):
|
|||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(4 * n_embed)
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -1798,10 +1796,9 @@ class MPTModel(TextModel):
|
|||
self.gguf_writer.add_unk_token_id(0)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layers"]
|
||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
|
||||
self.gguf_writer.add_head_count(self.hparams["n_heads"])
|
||||
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
|
||||
|
|
@ -1834,7 +1831,6 @@ class OrionModel(TextModel):
|
|||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
|
|
@ -1852,7 +1848,7 @@ class OrionModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
|
|
@ -1869,7 +1865,6 @@ class BaichuanModel(TextModel):
|
|||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
|
|
@ -1886,7 +1881,7 @@ class BaichuanModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
|
|
@ -1993,7 +1988,6 @@ class XverseModel(TextModel):
|
|||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
|
|
@ -2010,7 +2004,7 @@ class XverseModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
|
|
@ -2053,10 +2047,6 @@ class FalconModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.FALCON
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams.get("num_hidden_layers")
|
||||
if block_count is None:
|
||||
block_count = self.hparams["n_layer"] # old name
|
||||
|
||||
n_head = self.hparams.get("num_attention_heads")
|
||||
if n_head is None:
|
||||
n_head = self.hparams["n_head"] # old name
|
||||
|
|
@ -2069,7 +2059,7 @@ class FalconModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -2107,12 +2097,10 @@ class StarCoderModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.STARCODER
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(1)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -2142,14 +2130,12 @@ class RefactModel(TextModel):
|
|||
multiple_of = 256
|
||||
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
# refact uses Alibi. So this is from config.json which might be used by training.
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
|
||||
self.gguf_writer.add_feed_forward_length(ff_dim)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(1)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -2196,11 +2182,10 @@ class StableLMModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
|
|
@ -3151,7 +3136,7 @@ class DbrxModel(TextModel):
|
|||
def set_gguf_parameters(self):
|
||||
ffn_config = self.hparams["ffn_config"]
|
||||
attn_config = self.hparams["attn_config"]
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
|
|
@ -3353,7 +3338,7 @@ class QwenModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
||||
|
|
@ -4384,7 +4369,7 @@ class GPT2Model(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.GPT2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
|
|
@ -4416,8 +4401,6 @@ class Phi2Model(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
|
|
@ -4426,7 +4409,7 @@ class Phi2Model(TextModel):
|
|||
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(4 * n_embd)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head)
|
||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
|
||||
|
|
@ -4544,8 +4527,6 @@ class Phi3MiniModel(TextModel):
|
|||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
|
|
@ -4559,7 +4540,7 @@ class Phi3MiniModel(TextModel):
|
|||
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
|
||||
|
|
@ -4679,12 +4660,11 @@ class PlamoModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(4096) # not in config.json
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
|
|
@ -4807,7 +4787,6 @@ class Plamo2Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||
|
||||
# Which layers are Mamba layers
|
||||
|
|
@ -4819,10 +4798,10 @@ class Plamo2Model(TextModel):
|
|||
num_attention_heads = []
|
||||
|
||||
if mamba_enabled:
|
||||
for i in range(block_count):
|
||||
if block_count <= (mamba_step // 2):
|
||||
for i in range(self.block_count):
|
||||
if self.block_count <= (mamba_step // 2):
|
||||
# use attention in last layer
|
||||
is_mamba = (i != block_count - 1)
|
||||
is_mamba = (i != self.block_count - 1)
|
||||
else:
|
||||
is_mamba = (i % mamba_step) != (mamba_step // 2)
|
||||
if is_mamba:
|
||||
|
|
@ -4840,7 +4819,7 @@ class Plamo2Model(TextModel):
|
|||
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
|
||||
self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
|
||||
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
|
||||
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
|
||||
|
||||
|
|
@ -4897,12 +4876,10 @@ class CodeShellModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.CODESHELL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -5044,7 +5021,7 @@ class InternLM2Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||
|
|
@ -5665,11 +5642,10 @@ class GemmaModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||
|
|
@ -5705,11 +5681,10 @@ class Gemma2Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||
|
|
@ -5753,12 +5728,11 @@ class Gemma3Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
# some default values are not specified in the hparams
|
||||
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
|
||||
|
|
@ -6034,7 +6008,6 @@ class Rwkv6Model(TextModel):
|
|||
self._set_vocab_rwkv_world()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_size = self.hparams["head_size"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
|
|
@ -6046,7 +6019,7 @@ class Rwkv6Model(TextModel):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
|
|
@ -6110,7 +6083,6 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
|||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
num_attention_heads = self.hparams["num_attention_heads"]
|
||||
num_key_value_heads = self.hparams["num_key_value_heads"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
|
|
@ -6123,7 +6095,7 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
|
||||
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
|
||||
|
|
@ -6164,7 +6136,6 @@ class Rwkv7Model(TextModel):
|
|||
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
try:
|
||||
head_size = self.hparams["head_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
|
|
@ -6189,7 +6160,7 @@ class Rwkv7Model(TextModel):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
|
|
@ -6283,7 +6254,6 @@ class ARwkv7Model(Rwkv7Model):
|
|||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
head_size = self.hparams["head_size"]
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
|
|
@ -6300,7 +6270,7 @@ class ARwkv7Model(Rwkv7Model):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
|
|
@ -7524,7 +7494,7 @@ class T5Model(TextModel):
|
|||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
|
||||
self.gguf_writer.add_decoder_block_count(dec_n_layer)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
|
|
@ -7663,7 +7633,7 @@ class T5EncoderModel(TextModel):
|
|||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
||||
|
|
@ -7726,7 +7696,7 @@ class JaisModel(TextModel):
|
|||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
|
||||
|
|
@ -8068,7 +8038,7 @@ class ChatGLMModel(TextModel):
|
|||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
|
||||
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
|
||||
|
|
@ -8150,7 +8120,6 @@ class ExaoneModel(TextModel):
|
|||
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
|
||||
layer_norm_eps = hparams["layer_norm_epsilon"]
|
||||
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
|
||||
num_layers = hparams["num_layers"]
|
||||
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
|
||||
# attention_dropout_rate = hparams["attention_dropout"]
|
||||
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
|
||||
|
|
@ -8161,7 +8130,7 @@ class ExaoneModel(TextModel):
|
|||
self.gguf_writer.add_context_length(max_position_embeddings)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_block_count(num_layers)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
||||
|
|
|
|||
|
|
@@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
    return parser.parse_args()


def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
    from huggingface_hub import try_to_load_from_cache

    # normally, adapter does not come with base model config, we need to load it from AutoConfig
    config = AutoConfig.from_pretrained(hf_model_id)
    return config.to_dict()
    cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
    cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None

    return config.to_dict(), cache_dir


if __name__ == '__main__':

@@ -325,13 +330,13 @@ if __name__ == '__main__':
    # load base model
    if base_model_id is not None:
        logger.info(f"Loading base model from Hugging Face: {base_model_id}")
        hparams = load_hparams_from_hf(base_model_id)
        hparams, dir_base_model = load_hparams_from_hf(base_model_id)
    elif dir_base_model is None:
        if "base_model_name_or_path" in lparams:
            model_id = lparams["base_model_name_or_path"]
            logger.info(f"Loading base model from Hugging Face: {model_id}")
            try:
                hparams = load_hparams_from_hf(model_id)
                hparams, dir_base_model = load_hparams_from_hf(model_id)
            except OSError as e:
                logger.error(f"Failed to load base model config: {e}")
                logger.error("Please try downloading the base model and add its path to --base")

@@ -480,6 +485,7 @@ if __name__ == '__main__':
            dir_lora_model=dir_lora,
            lora_alpha=alpha,
            hparams=hparams,
            remote_hf_model_id=base_model_id,
        )

        logger.info("Exporting model...")
docs/ops.md
@ -17,12 +17,12 @@ Legend:
|
|||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
|
|
@ -43,9 +43,9 @@ Legend:
|
|||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
|
|
@ -87,7 +87,7 @@ Legend:
|
|||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
|
|
@ -99,7 +99,7 @@ Legend:
|
|||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
|
@ -107,7 +107,7 @@ Legend:
|
|||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
|
|
@ -116,6 +116,6 @@ Legend:
|
|||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@
|
|||
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
|
|
@ -29,18 +29,18 @@
|
|||
"Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
|
||||
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||
|
|
@ -89,8 +89,8 @@
|
|||
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
|
|
@ -113,18 +113,18 @@
|
|||
"Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||
"Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
|
||||
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||
|
|
@ -5654,7 +5654,7 @@
|
|||
"Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan"
|
||||
|
|
@ -8632,10 +8632,10 @@
|
|||
"Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
|
||||
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
|
|
@ -8643,10 +8643,10 @@
|
|||
"Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
|
||||
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
||||
|
|
@ -8654,10 +8654,10 @@
|
|||
"Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
|
||||
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
|
|
@ -8665,10 +8665,10 @@
|
|||
"Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
|
||||
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan"
|
||||
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan"
|
||||
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan"
|
||||
|
|
@ -9478,7 +9478,7 @@
|
|||
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
|
||||
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
|
||||
"Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan"
|
||||
"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","Vulkan"
|
||||
"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan"
|
||||
"Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan"
|
||||
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
||||
"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
|
||||
|
|
@ -9487,9 +9487,9 @@
|
|||
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan"
|
||||
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan"
|
||||
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan"
|
||||
"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","Vulkan"
|
||||
"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","Vulkan"
|
||||
"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","Vulkan"
|
||||
"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan"
|
||||
"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan"
|
||||
"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan"
|
||||
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan"
|
||||
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan"
|
||||
|
|
|
|||
|
|
|
|
@@ -4,10 +4,10 @@
#include "llama.h"
#include "ggml.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <numeric>

/**
 * This the arbitrary data which will be passed to each callback.

@@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
    return u.f;
}

static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
    float v;
    if (type == GGML_TYPE_F16) {
        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
        v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
    } else if (type == GGML_TYPE_F32) {
        v = *(float *) &data[i];
        v = *(const float *) &data[i];
    } else if (type == GGML_TYPE_I64) {
        v = (float) *(int64_t *) &data[i];
        v = (float) *(const int64_t *) &data[i];
    } else if (type == GGML_TYPE_I32) {
        v = (float) *(int32_t *) &data[i];
        v = (float) *(const int32_t *) &data[i];
    } else if (type == GGML_TYPE_I16) {
        v = (float) *(int16_t *) &data[i];
        v = (float) *(const int16_t *) &data[i];
    } else if (type == GGML_TYPE_I8) {
        v = (float) *(int8_t *) &data[i];
        v = (float) *(const int8_t *) &data[i];
    } else if (type == GGML_TYPE_BF16) {
        v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
        v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
    } else {
        GGML_ABORT("fatal error");
    }
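The function above indexes a tensor by byte strides: nb[d] is the stride of dimension d in bytes, so the flat offset is i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0], and the bytes at that offset are reinterpreted according to the element type. A self-contained sketch of the same addressing for a plain row-major f32 buffer (the shapes and strides here are invented for illustration):

```cpp
// Sketch of byte-stride addressing in the style of ggml tensors (nb[] holds byte strides).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

static float get_f32(const uint8_t * data, const size_t * nb,
                     size_t i0, size_t i1, size_t i2, size_t i3) {
    const size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
    return *(const float *) &data[i];
}

int main() {
    // a contiguous 4x3 f32 "tensor": ne = {4, 3, 1, 1}
    const size_t ne[4] = {4, 3, 1, 1};
    const size_t nb[4] = {
        sizeof(float),                         // nb[0]: step between elements in dim 0
        ne[0] * sizeof(float),                 // nb[1]: step between rows
        ne[0] * ne[1] * sizeof(float),         // nb[2]
        ne[0] * ne[1] * ne[2] * sizeof(float), // nb[3]
    };

    std::vector<float> buf(ne[0] * ne[1]);
    for (size_t i = 0; i < buf.size(); ++i) {
        buf[i] = (float) i;
    }

    // element (i0=2, i1=1) of the 4x3 tensor -> prints 6.0
    std::printf("%.1f\n", get_f32((const uint8_t *) buf.data(), nb, 2, 1, 0, 0));
    return 0;
}
```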
@@ -392,9 +392,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
            string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")

            if (EXTRACTED_NUMBER GREATER_EQUAL 10)
                list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
                list(APPEND ARCH_FLAGS -mcpu=power10)
            elseif (EXTRACTED_NUMBER EQUAL 9)
                list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
                list(APPEND ARCH_FLAGS -mcpu=power9)
            elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
                list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
            else()
@@ -39,7 +39,7 @@

#include "kernels.h"

#define NELEMS(x) sizeof(x) / sizeof(*x)
#define NELEMS(x) (sizeof(x) / sizeof(*x))

template<size_t(*Fn)(size_t,size_t,size_t)>
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {

@@ -635,6 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
    },
#endif
#endif
    { /* Sentinel */ }
};

static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {

@@ -803,6 +804,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
    { /* Sentinel */ }
};

ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {

@@ -810,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c

    if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
                gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&

@@ -820,7 +822,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
            }
        }
        if (!kernel) {
            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
                if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
                    gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
                    gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&

@@ -830,6 +832,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
                }
            }
        }
#else
        GGML_UNUSED(gemm_gemv_kernels);
        GGML_UNUSED(gemm_gemv_kernels_q8);
        GGML_UNUSED(cpu_features);
#endif
    }

@@ -840,12 +846,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
    ggml_kleidiai_kernels * kernels = nullptr;

#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
        if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
            kernels = &gemm_gemv_kernels[i];
            break;
        }
    }
#else
    GGML_UNUSED(features);
#endif

    return kernels;

@@ -855,12 +863,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features)
    ggml_kleidiai_kernels * kernels = nullptr;

#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
        if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
            kernels = &gemm_gemv_kernels_q8[i];
            break;
        }
    }
#else
    GGML_UNUSED(features);
#endif

    return kernels;
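Both kernel tables now end with an empty `{ /* Sentinel */ }` entry, and the lookup loops iterate to `NELEMS(...) - 1` so the zero-initialized sentinel can never be matched (a feature mask of zero would otherwise satisfy the bitmask test for any CPU). A small standalone sketch of the same pattern; the entry struct and feature bits here are invented stand-ins, not the actual KleidiAI types:

```cpp
// Sketch of a sentinel-terminated table with a NELEMS(x) - 1 loop bound.
#include <cstddef>
#include <cstdio>

#define NELEMS(x) (sizeof(x) / sizeof(*x))

struct kernel_entry {
    unsigned required_cpu; // bitmask of required features; 0 in the sentinel
    const char * name;     // nullptr in the sentinel
};

static kernel_entry kernels[] = {
    { 0x1, "dotprod" },
    { 0x2, "i8mm"    },
    { /* Sentinel */ }
};

static const kernel_entry * select_kernel(unsigned features) {
    // NELEMS(kernels) - 1 skips the sentinel; a zeroed required_cpu would match any feature set.
    for (size_t i = 0; i < NELEMS(kernels) - 1; ++i) {
        if ((features & kernels[i].required_cpu) == kernels[i].required_cpu) {
            return &kernels[i];
        }
    }
    return nullptr;
}

int main() {
    const kernel_entry * k = select_kernel(0x2);
    std::printf("%s\n", k ? k->name : "none"); // prints "i8mm"
    return 0;
}
```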

@@ -9696,13 +9696,12 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
for (int64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (int64_t t = 0; t < i00; ++t) {
sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
}

const float diag = A_batch[i00 * n + i00];
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");

X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
}
}
}
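The loop above is plain forward substitution for a lower-triangular solve A * X = B; the fix stores X with the same row-major n x k layout as B, so the solved element lands at X_batch[i00 * k + i01]. A minimal scalar sketch of that layout, using hypothetical names rather than the ggml batch pointers:

#include <cassert>
#include <vector>

// Solve A * X = B where A is n x n lower triangular (row-major) and
// B, X are n x k (row-major). Illustrative only; the indexing mirrors
// the X_batch[i00 * k + i01] form used in the kernel above.
static void solve_tri_lower(const std::vector<float> & A,
                            const std::vector<float> & B,
                            std::vector<float> & X, int n, int k) {
    for (int col = 0; col < k; ++col) {      // i01 in the kernel
        for (int row = 0; row < n; ++row) {  // i00 in the kernel
            float sum = 0.0f;
            for (int t = 0; t < row; ++t) {
                sum += A[row * n + t] * X[t * k + col];
            }
            const float diag = A[row * n + row];
            assert(diag != 0.0f && "zero diagonal in triangular matrix");
            X[row * k + col] = (B[row * k + col] - sum) / diag;
        }
    }
}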

@@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32xt svfloat32_t
#define GGML_F32xt_ZERO svdup_n_f32(0.0f)
#define GGML_F32xt_SET1(x) svdup_n_f32(x)
#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a)
#define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
#define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
{ \
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \

@@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
(res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
}
#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)

#define GGML_F32_VEC GGML_F32xt
#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO

@@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))

#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED

#define GGML_F16x_VEC GGML_F32Cxt

@@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE

#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)

#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
{ \

@@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
(res) = (ggml_float) sum_f16; \
}
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)

// F16 NEON

@@ -698,60 +698,61 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
}

inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 2 * ggml_f16_epr;
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 2 * ggml_f16_epr;

GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ay1, ay2;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ay1, ay2;

for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);

ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
}
// leftovers
// maximum number of leftover elements will be less than ggmlF_16x_epr. Apply predicated svmad on available elements only
if (np < n) {
svbool_t pg = svwhilelt_b16(np, n);
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
for (int i = 0, vl; i < n; i += vl) {
vl = __riscv_vsetvl_e16m2(n - i);
vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
__riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
}
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

GGML_F16_VEC ay[GGML_F16_ARR];

for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);

GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
// leftovers
// maximum number of leftover elements will be less than ggmlF_16x_epr. Apply predicated svmad on available elements only
if (np < n) {
svbool_t pg = svwhilelt_b16(np, n);
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
}

GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

GGML_F16_VEC ay[GGML_F16_ARR];

for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);

GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}

// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {

@@ -384,7 +384,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src1_ddc = (char *) src1->data;

const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);

if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
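The tightened can_be_transposed check above accepts src0 only when dim 1 is the packed dimension (nb01 equals the element size) and the 2-D planes are dense (nb02 == ne00 * ne01 * element size). A hedged, standalone sketch of the stride arithmetic being tested, with hypothetical sizes that model a transposed f32 view:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical transposed fp32 view (ggml-style nb[] are byte strides per dim).
    const int64_t esize = 4;                   // fp32
    const int64_t ne00 = 4, ne01 = 3;          // shape of the transposed view
    const int64_t nb01 = esize;                // dim 1 steps by a single element
    const int64_t nb02 = ne00 * ne01 * esize;  // 2-D planes packed back to back

    const bool can_be_transposed = nb01 == esize && nb02 == ne00 * ne01 * esize;
    printf("can_be_transposed = %d\n", (int)can_be_transposed); // prints 1
    return 0;
}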

@@ -3001,6 +3001,10 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
const ggml_tensor * view,
const ggml_tensor * set_rows) {

if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
return false;
}
// ne3 not tested
if (rope->src[0]->ne[3] != 1) {
return false;

@@ -3744,10 +3748,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str();
}

#if defined(__linux__)
// Helper function to get available memory from /proc/meminfo for UMA systems
static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
FILE * meminfo_file = nullptr;
// 2 KB buffer for reading /proc/meminfo; the file does not report its size, so this should be enough
const size_t BUFFER_SIZE = 2048;
auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
size_t bytes_read = 0;
long huge_tlb_total_pages = -1;
long huge_tlb_free_pages = -1;
long huge_tlb_page_size = -1;

if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
return false;
}

meminfo_file = fopen("/proc/meminfo", "r");
if (meminfo_file == nullptr) {
GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
return false;
}

// Read file into buffer
bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
fclose(meminfo_file);

if (bytes_read == 0) {
GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
return false;
}
file_buffer[bytes_read] = '\0';

*available_memory_kb = -1;
*free_swap_kb = -1;

// Parse the file buffer line by line
char * line = file_buffer.get();
char * line_next;
while (line < file_buffer.get() + bytes_read) {
// Find the end of the current line
line_next = strchr(line, '\n');
if (line_next != nullptr) {
*line_next = '\0';
line_next++;
} else {
line_next = file_buffer.get() + bytes_read;
}

long value;
if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
*available_memory_kb = value;
} else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
*free_swap_kb = value;
} else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
huge_tlb_total_pages = value;
} else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
huge_tlb_free_pages = value;
} else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
huge_tlb_page_size = value;
}

line = line_next;
}

if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
*available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;

// Hugetlbfs pages are not swappable.
*free_swap_kb = 0;
}

GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
return true;
}
#endif // defined(__linux__)

static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemGetInfo(free, total));

// ref: https://github.com/ggml-org/llama.cpp/pull/17368
#if defined(__linux__)
// Check if this is a UMA (Unified Memory Architecture) system
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));

// Check if UMA is explicitly enabled via environment variable
bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
bool is_uma = prop.unifiedAddressing > 0 || uma_env;

if (is_uma) {
// For UMA systems (like DGX Spark), use system memory info
long available_memory_kb = 0;
long free_swap_kb = 0;

if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
*free = (size_t)available_memory_kb * 1024;
} else {
GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
}
}
#endif // defined(__linux__)

}

static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
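A minimal standalone sketch of the same /proc/meminfo probe, handy for checking what the helper above would report on a given machine. Illustrative only; it reads MemAvailable with std::ifstream rather than the fixed fread buffer used in the backend code:

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

int main() {
    std::ifstream meminfo("/proc/meminfo");
    std::string line;
    long available_kb = -1;

    while (std::getline(meminfo, line)) {
        std::istringstream iss(line);
        std::string key;
        long value = 0;
        if (iss >> key >> value && key == "MemAvailable:") {
            available_kb = value; // reported in kB
            break;
        }
    }

    std::cout << "MemAvailable: " << available_kb << " kB\n";
    return 0;
}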

@@ -11,6 +11,7 @@
#include <cassert>
#include <algorithm>
#include <limits>
#include <cmath>

static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
if (!t) {

@@ -406,8 +406,8 @@ enum shader_reduction_mode {
SHADER_REDUCTION_MODE_COUNT,
};

// argsort pipelines for up to 1<<10 invocations per workgroup
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;

static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,

@@ -526,6 +526,7 @@ struct vk_device_struct {
bool multi_add;
bool shader_int64;
bool buffer_device_address;
bool vulkan_memory_model;

bool add_rms_fusion;
uint32_t partials_binding_alignment;

@@ -539,6 +540,9 @@ struct vk_device_struct {
uint32_t subgroup_max_size;
bool subgroup_require_full_support;

// floor(log2(maxComputeWorkGroupInvocations))
uint32_t max_workgroup_size_log2 {};

bool coopmat_support;
bool coopmat_acc_f32_support {};
bool coopmat_acc_f16_support {};

@@ -638,6 +642,7 @@ struct vk_device_struct {
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
vk_pipeline pipeline_norm_f32;

@@ -664,6 +669,20 @@ struct vk_device_struct {
vk_pipeline pipeline_hardsigmoid[2];
vk_pipeline pipeline_hardswish[2];
vk_pipeline pipeline_abs[2];
vk_pipeline pipeline_softplus[2];
vk_pipeline pipeline_step[2];
vk_pipeline pipeline_round[2];
vk_pipeline pipeline_ceil[2];
vk_pipeline pipeline_floor[2];
vk_pipeline pipeline_trunc[2];

vk_pipeline pipeline_add1_f16_f16;
vk_pipeline pipeline_add1_f16_f32;
vk_pipeline pipeline_add1_f32_f32;

vk_pipeline pipeline_arange_f32;

vk_pipeline pipeline_fill_f32;

vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];

@@ -683,6 +702,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;

@@ -1173,8 +1193,14 @@ struct vk_op_soft_max_push_constants {

struct vk_op_argsort_push_constants {
uint32_t ncols;
uint32_t ncols_padded;
uint32_t ncols_padded_log2;
uint32_t nrows;
int32_t order;
uint32_t order;
uint32_t outer_start;
uint32_t outer_end;
uint32_t inner_start;
uint32_t inner_end;
};

struct vk_op_im2col_push_constants {

@@ -2901,15 +2927,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} \
} \

@@ -3697,6 +3723,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);

if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);

@@ -3826,6 +3855,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(hardsigmoid)
CREATE_UNARY(hardswish)
CREATE_UNARY(abs)
CREATE_UNARY(softplus)
CREATE_UNARY(step)
CREATE_UNARY(round)
CREATE_UNARY(ceil)
CREATE_UNARY(floor)
CREATE_UNARY(trunc)
#undef CREATE_UNARY

#define CREATE_UNARY_RTE(name) \

@@ -3839,6 +3874,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY_RTE(exp)
#undef CREATE_UNARY_RTE

ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

#define CREATE_GLU(name) \
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \

@@ -3891,7 +3934,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
}

for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
if (i <= device->max_workgroup_size_log2 &&
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
const uint32_t NCOLS_PADDED_LOG2 = i;
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
}
const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
BLOCK_SIZE /= WG_UNROLL_FACTOR;
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
}

ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

@@ -4292,6 +4343,8 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;

device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));

std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();

// Try to find a non-graphics compute queue and transfer-focused queues

@@ -4431,6 +4484,7 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->shader_int64 = device_features2.features.shaderInt64;
device->buffer_device_address = vk12_features.bufferDeviceAddress;
device->vulkan_memory_model = vk12_features.vulkanMemoryModel;

if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;

@@ -6247,6 +6301,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
// Choose "contiguous copy" shader if src/dst are contiguous
bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));

// Use optimized "transpose" shader if src dim1 is the innermost dimension.
bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src);

if (transpose && src->type == to) {
if (ggml_type_size(to) == 4) {
return ctx->device->pipeline_cpy_transpose_32;
} else if (ggml_type_size(to) == 2) {
return ctx->device->pipeline_cpy_transpose_16;
}
}

if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
if (contig) {
return ctx->device->pipeline_contig_cpy_f32_f32;

@@ -8242,6 +8307,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_ABS:
return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SOFTPLUS:
return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_STEP:
return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_ROUND:
return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_CEIL:
return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_FLOOR:
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_TRUNC:
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
default:
break;
}

@@ -8344,19 +8421,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
}
case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
}

if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
}
return nullptr;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:

@@ -8449,7 +8513,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_CONV_TRANSPOSE_2D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
std::array<uint32_t, 3> elements;
std::array<uint32_t, 3> elements{};
if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
vk_conv_shapes shape;

@@ -8527,6 +8591,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
}
return nullptr;
case GGML_OP_ADD1:
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f16;
}
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f32;
}
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_add1_f32_f32;
}
return nullptr;
case GGML_OP_ARANGE:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_arange_f32;
}
return nullptr;
case GGML_OP_FILL:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_fill_f32;
}
return nullptr;
default:
return nullptr;
}

@@ -8748,8 +8833,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
GGML_ASSERT(0);
break;
case GGML_OP_IM2COL:
{

@@ -8817,6 +8901,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_MUL:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:

@@ -8858,6 +8945,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} else {
elements = { ne, 1, 1 };
}

if (pipeline == ctx->device->pipeline_cpy_transpose_32 ||
pipeline == ctx->device->pipeline_cpy_transpose_16) {
// 32x32 tiles
elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32);
elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32);
elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]);
elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]);
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
}
} break;
case GGML_OP_ADD_ID:
{

@@ -9423,6 +9521,63 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
}

static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);

ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}

static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");

vk_op_push_constants pc = {
(uint32_t)ggml_nelements(dst),
1,
ggml_get_op_params_f32(dst, 0),
ggml_get_op_params_f32(dst, 2),
};

vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
GGML_ASSERT(pipeline != nullptr);

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);

std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
}

static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");

vk_op_push_constants pc = {
(uint32_t)ggml_nelements(dst),
1,
ggml_get_op_params_f32(dst, 0),
0.0f,
};

vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
GGML_ASSERT(pipeline != nullptr);

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);

std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
}

static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
}

@@ -9865,16 +10020,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
}

static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
int32_t * op_params = (int32_t *)dst->op_params;
const uint32_t * op_params = (const uint32_t *)dst->op_params;

uint32_t ncols = src0->ne[0];
uint32_t nrows = ggml_nrows(src0);

ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
nrows,
op_params[0],
});
uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
uint32_t ncolsp2 = 1 << ncols_pad_log2;

vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };

// Pick the largest workgroup size <= ncolsp2
uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);

// Use the "small" argsort shader if the whole sort can be done by a single workgroup.
bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;

vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
: ctx->device->pipeline_argsort_large_f32[pipeline_idx];

vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer subbuf1 = dst_buf;

// Reserve space for ivec2 per element, with rows padded to a power of two
if (!use_small) {
const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);

if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
}

std::array<uint32_t, 3> elements;

elements[0] = ncolsp2;
elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = 1;

// First dispatch initializes tmp_idx and does the first N passes where
// there is only communication between threads in the same workgroup.
{
vk_op_argsort_push_constants pc2 = pc;
pc2.outer_start = 0;
pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
pc2.inner_start = 0;
pc2.inner_end = 100;
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
}
if (!use_small) {
ggml_vk_sync_buffers(ctx, subctx);
// Loop over outer/inner passes, synchronizing between each pass.
for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
for (uint32_t inner = 0; inner < outer + 1; ++inner) {
vk_op_argsort_push_constants pc2 = pc;
pc2.outer_start = outer;
pc2.outer_end = outer + 1;
pc2.inner_start = inner;
pc2.inner_end = inner + 1;
// When the inner idx is large enough, there's only communication
// within a workgroup. So the remaining inner iterations can all
// run in the same dispatch.
if (outer - inner < pipeline_idx) {
pc2.inner_end = 100;
inner = outer;
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
} else {
// Smaller workgroup empirically seems to perform better
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
}
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
ggml_vk_sync_buffers(ctx, subctx);
}
}
ctx->prealloc_x_need_sync = true;
}
}

static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
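For readers unfamiliar with bitonic sorting, the outer/inner split above mirrors the classic pass structure: the outer pass works on blocks of 2^(outer+1) elements and the inner pass halves the exchange stride each step; once the stride fits inside one workgroup, the remaining inner passes need no cross-workgroup synchronization and can share a dispatch. A hedged host-side sketch of that schedule, with hypothetical sizes and no Vulkan calls:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical problem: 1 << 13 padded columns, workgroups of 1 << 10 threads.
    const uint32_t ncols_pad_log2 = 13;
    const uint32_t wg_log2        = 10;

    // Passes 0..wg_log2-1 only exchange within a workgroup: one dispatch.
    printf("dispatch 1: outer passes 0..%u, workgroup scope only\n",
           (unsigned)(std::min(ncols_pad_log2, wg_log2) - 1));

    // The remaining passes exchange across workgroups until the stride shrinks
    // back below the workgroup size; each such (outer, inner) pair is its own
    // dispatch with a buffer barrier in between.
    int dispatches = 1;
    for (uint32_t outer = wg_log2; outer < ncols_pad_log2; ++outer) {
        for (uint32_t inner = 0; inner <= outer; ++inner) {
            ++dispatches;
            if (outer - inner < wg_log2) {
                // strides now fit in one workgroup: fold the rest into this dispatch
                break;
            }
        }
    }
    printf("total dispatches: %d\n", dispatches);
    return 0;
}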

@@ -11182,6 +11410,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
break;
default:
return false;

@@ -11223,6 +11457,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:

@@ -11435,6 +11672,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_UPSCALE:
ggml_vk_upscale(ctx, compute_ctx, src0, node);

break;
case GGML_OP_ADD1:
ggml_vk_add1(ctx, compute_ctx, src0, src1, node);

break;
case GGML_OP_ARANGE:
ggml_vk_arange(ctx, compute_ctx, node);

break;
case GGML_OP_FILL:
ggml_vk_fill(ctx, compute_ctx, node);

break;
case GGML_OP_SCALE:
ggml_vk_scale(ctx, compute_ctx, src0, node);

@@ -11519,6 +11768,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
default:

@@ -11721,6 +11976,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:

@@ -11792,6 +12050,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
buf = tensor->buffer;
break;
default:

@@ -13394,6 +13658,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&

@@ -13695,10 +13965,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_ARGSORT:
return op->ne[0] <= max_argsort_cols;
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// pipeline_argsort_large_f32 requires vulkan memory model.
if (device->vulkan_memory_model) {
return true;
} else {
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
}
}
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_SCALE:
case GGML_OP_PAD:
case GGML_OP_ROLL:

@@ -14181,6 +14466,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_SCALE) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
} else if (tensor->op == GGML_OP_ADD1) {
tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ARANGE) {
const float start = ggml_get_op_params_f32(tensor, 0);
const float stop = ggml_get_op_params_f32(tensor, 1);
const float step = ggml_get_op_params_f32(tensor, 2);
tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
} else if (tensor->op == GGML_OP_FILL) {
const float value = ggml_get_op_params_f32(tensor, 0);
tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
} else if (tensor->op == GGML_OP_SQR) {
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SQRT) {

@@ -14294,6 +14589,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_ABS:
tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_SOFTPLUS:
tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_STEP:
tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_ROUND:
tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_CEIL:
tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_FLOOR:
tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_TRUNC:
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
break;
default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");

@@ -0,0 +1,28 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require

#include "types.glsl"
#include "generic_binary_head.glsl"

const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

void main() {
uint idx = get_idx();

const uint num_iter = 2;

[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);

data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset()]));

idx += num_threads;
}
}

@@ -0,0 +1,20 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};

void main() {
const uint i = gl_GlobalInvocationID.x;

if (i >= p.KX) {
return;
}

// p.param1 = start, p.param2 = step
float value = p.param1 + p.param2 * float(i);
data_d[i] = D_TYPE(value);
}

@@ -4,28 +4,27 @@
#include "types.glsl"

layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
#define ASC 0

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (binding = 2) writeonly buffer D {int data_d[];};

layout (push_constant) uniform parameter {
uint ncols;
uint ncols_padded;
uint ncols_padded_log2;
uint nrows;
uint order;
uint outer_start;
uint outer_end;
uint inner_start;
uint inner_end;
} p;

shared int dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];

void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx0] = dst_row[idx1];
dst_row[idx1] = tmp;
}
shared ivec2 dst_row[BLOCK_SIZE];

void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort

@@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) {
const uint row_offset = row * p.ncols;

// initialize indices
dst_row[col] = col;
a_sh[col] = data_a[row_offset + col];
dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
barrier();

uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {

@@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) {
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;

int sh_idx_0 = dst_row[idx_0];
int sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
ivec2 sh_idx_0 = dst_row[idx_0];
ivec2 sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

if ((idx_0_oob ||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
swap(idx_0, idx_1);
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
dst_row[idx_0] = sh_idx_1;
dst_row[idx_1] = sh_idx_0;
}

barrier();

@@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) {

if (col < p.ncols) {
if (p.order == ASC) {
data_d[row_offset + col] = dst_row[col];
data_d[row_offset + col] = dst_row[col].x;
} else {
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
}
}
}

@@ -0,0 +1,114 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_memory_scope_semantics : enable
#pragma use_vulkan_memory_model

#include "types.glsl"

layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2;
#define ASC 0

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};
layout (binding = 2) workgroupcoherent buffer D {int data_d[];};

layout (push_constant) uniform parameter {
uint ncols;
uint ncols_padded;
uint ncols_padded_log2;
uint nrows;
uint order;
uint outer_start;
uint outer_end;
uint inner_start;
uint inner_end;
} p;

void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
int col = int(gl_GlobalInvocationID.x);
col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;

const uint row_offset = row * p.ncols;
uint idx_offset = row * p.ncols_padded;

bool need_barrier = false;

// initialize indices
if (p.outer_start == 0 && p.inner_start == 0) {
[[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
uint c = u*BLOCK_SIZE + col;
if (c < p.ncols_padded) {
ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c]));
tmp_idx[idx_offset + c] = v;
}
}
need_barrier = true;
}

[[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {
uint inner_end = min(p.inner_end, outer_idx + 1);
for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {
if (need_barrier) {
controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
}
need_barrier = true;
[[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
int c = u*BLOCK_SIZE + col;
const int ixj = int(c ^ j);

if (ixj < c) {
continue;
}

int idx_0 = (c & k) == 0 ? c : ixj;
int idx_1 = (c & k) == 0 ? ixj : c;

ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];
ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

if ((idx_0_oob ||
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {
tmp_idx[idx_offset + idx_0] = sh_idx_1;
tmp_idx[idx_offset + idx_1] = sh_idx_0;
}
}
}
}

if (p.outer_end == p.ncols_padded_log2 &&
p.inner_end >= p.ncols_padded_log2 + 1) {
controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
[[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
uint c = u*BLOCK_SIZE + col;
if (c < p.ncols) {
if (p.order == ASC) {
data_d[row_offset + c] = tmp_idx[idx_offset + c].x;
} else {
data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;
}
}
}
}
}

void main() {
if (p.ncols == p.ncols_padded) {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}
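Both argsort shaders carry the key as floatBitsToInt(value) next to the column index, so a single ivec2 load replaces a second gather into the values array; the comparison converts back with intBitsToFloat rather than comparing raw bits. A hedged C++ analogue of that packing, illustrative only:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

// GLSL floatBitsToInt / intBitsToFloat analogues via memcpy bit casts.
static int32_t float_bits_to_int(float f) { int32_t i; std::memcpy(&i, &f, sizeof(i)); return i; }
static float   int_bits_to_float(int32_t i) { float f; std::memcpy(&f, &i, sizeof(f)); return f; }

int main() {
    const std::vector<float> row = { 3.5f, -1.0f, 2.0f, 0.25f };

    // Pack (index, key bits) pairs, mirroring the shader's ivec2 dst_row entries.
    std::vector<std::pair<int32_t, int32_t>> packed;
    for (int32_t i = 0; i < (int32_t)row.size(); ++i) {
        packed.push_back({ i, float_bits_to_int(row[i]) });
    }

    // Compare on the recovered float, exactly as the shader does with intBitsToFloat.
    std::sort(packed.begin(), packed.end(), [](const auto & a, const auto & b) {
        return int_bits_to_float(a.second) < int_bits_to_float(b.second);
    });

    for (const auto & e : packed) {
        printf("%d ", e.first); // prints the ascending argsort: 1 3 2 0
    }
    printf("\n");
    return 0;
}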

@@ -0,0 +1,22 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

if (i >= p.KX) {
return;
}

const float x = float(data_a[i]);
data_d[i] = D_TYPE(ceil(x));
}

@@ -0,0 +1,67 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

// workgroup does 32x32 tile, but uses 32x8 threads
#define TILE_DIM 32
layout(local_size_x = 32, local_size_y = 8, local_size_z = 1) in;

shared uint sh[TILE_DIM][TILE_DIM + 1];

void iter(uvec3 wg_id) {
    const uint tile_col = wg_id.x;
    const uint tile_row = wg_id.y;

    const uint tid_col = gl_LocalInvocationID.x;
    const uint tid_row = gl_LocalInvocationID.y;

    const uint i2 = wg_id.z % p.ne12;
    const uint i3 = wg_id.z / p.ne12;
    const uint i02 = i2;
    const uint i03 = i3;

    // The workgroup does TILE_DIM x TILE_DIM, but swaps the LSBs of the
    // src coords to make memory accesses contiguous, dst has tid.x in i0,
    // src has tid.x in i01

    [[unroll]] for (uint y = 0; y < 4; ++y) {
        const uint i00 = tile_col * TILE_DIM + tid_row + 8 * y;
        const uint i01 = tile_row * TILE_DIM + tid_col;
        if (i00 < p.ne00 && i01 < p.ne01 && i02 < p.ne02 && i03 < p.ne03) {
            const uint src_idx = i00 * p.nb00 + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
            sh[tid_row + 8 * y][tid_col] = uint(data_a[get_aoffset() + src_idx]);
        }
    }

    barrier();

    [[unroll]] for (uint y = 0; y < 4; ++y) {
        const uint i0 = tile_col * TILE_DIM + tid_col;
        const uint i1 = tile_row * TILE_DIM + tid_row + 8 * y;
        if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
            const uint dst_idx = i0 * p.nb10 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
            // load transposed
            data_d[get_doffset() + dst_idx] = D_TYPE(sh[tid_col][tid_row + 8 * y]);
        }
    }
}

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

void main() {
    uint z = gl_WorkGroupID.z;
    uint y = gl_WorkGroupID.y;
    bool need_barrier = false;
    for (uint z = gl_WorkGroupID.z; z < p.ne12 * p.ne13; z += gl_NumWorkGroups.z) {
        for (uint y = gl_WorkGroupID.y; y < CEIL_DIV(p.ne11, TILE_DIM); y += gl_NumWorkGroups.y) {
            for (uint x = gl_WorkGroupID.x; x < CEIL_DIV(p.ne10, TILE_DIM); x += gl_NumWorkGroups.x) {
                if (need_barrier) {
                    barrier();
                }
                need_barrier = true;
                iter(uvec3(x, y, z));
            }
        }
    }
}
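The shader stages a TILE_DIM x TILE_DIM tile in shared memory, padded to TILE_DIM + 1 columns so that column-wise reads do not land on the same bank, and swaps the roles of tid.x/tid.y between the load and the store so that both global accesses stay contiguous. The observable result is simply a transpose of the first two dimensions; the following CPU sketch is illustrative only, not code from this commit.

// Plain CPU transpose with the same 32-wide tiling for cache locality.
// dst[c][r] = src[r][c]; the shared-memory staging in the shader is a
// GPU-only optimization and does not change the result.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<uint32_t> transpose(const std::vector<uint32_t> & src, int64_t rows, int64_t cols) {
    std::vector<uint32_t> dst(rows * cols);
    const int64_t TILE = 32; // TILE_DIM in the shader
    for (int64_t r0 = 0; r0 < rows; r0 += TILE) {
        for (int64_t c0 = 0; c0 < cols; c0 += TILE) {
            for (int64_t r = r0; r < std::min(r0 + TILE, rows); ++r) {
                for (int64_t c = c0; c < std::min(c0 + TILE, cols); ++c) {
                    dst[c * rows + r] = src[r * cols + c];
                }
            }
        }
    }
    return dst;
}

int main() {
    // 2x3 row-major {0 1 2 / 3 4 5} becomes 3x2 {0 3 / 1 4 / 2 5}
    const std::vector<uint32_t> t = transpose({0, 1, 2, 3, 4, 5}, 2, 3);
    for (uint32_t v : t) printf("%u ", v); // 0 3 1 4 2 5
    printf("\n");
}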

@@ -0,0 +1,19 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    // p.param1 = fill value
    data_d[i] = D_TYPE(p.param1);
}

@@ -0,0 +1,22 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);
    data_d[i] = D_TYPE(floor(x));
}

@@ -0,0 +1,29 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);
    float result;
    // Round halfway cases away from zero as roundf does.
    if (x >= 0.0) {
        result = floor(x + 0.5);
    } else {
        result = ceil(x - 0.5);
    }
    data_d[i] = D_TYPE(result);
}

@@ -0,0 +1,23 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);
    const float result = (x > 20.0f) ? x : log(1.0f + exp(x));
    data_d[i] = D_TYPE(result);
}
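The x > 20.0f guard is what keeps the softplus shader finite: exp(x) overflows single precision near x ≈ 88, while softplus(x) is already indistinguishable from x at 20 in float. A hedged CPU reference for the same thresholded form, using log1p for the non-saturated branch:

// CPU reference for the thresholded softplus above; std::log1p(exp(x)) is
// the numerically stable spelling of log(1 + exp(x)) for small exp(x).
#include <cmath>
#include <cstdio>

static float softplus_ref(float x) {
    return (x > 20.0f) ? x : std::log1p(std::exp(x));
}

int main() {
    // ~0.000045, 0.693147, 30.0 -- the guard avoids exp() overflow for large x
    printf("%f %f %f\n", softplus_ref(-10.0f), softplus_ref(0.0f), softplus_ref(30.0f));
}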

@@ -0,0 +1,22 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);
    data_d[i] = D_TYPE(x >= 0.0f ? 1.0f : 0.0f);
}

@@ -0,0 +1,22 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);
    data_d[i] = D_TYPE(trunc(x));
}

@@ -734,6 +734,9 @@ void process_shaders() {
    string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
    string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});

    string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
    string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});

    for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
        string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
        string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});

@@ -843,6 +846,25 @@ void process_shaders() {
    string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

    string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

    string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
    string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
    string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

    for (auto rte : {false, true}) {
        std::string suffix = rte ? "_rte" : "";
        string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});

@@ -889,6 +911,7 @@ void process_shaders() {
    string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});

    string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
    string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});

    string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
    string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

@@ -1 +1 @@
7b6abb2b92fcef35cb01c6ce6ada9bd85306522d
781baf2a14d9e0aaee542b2e1bb918bfc4132199

@@ -6,8 +6,10 @@

#include <cmath>
#include <algorithm>
#include <cstdint>
#include <stdexcept>

#define MAX_REPETITION_THRESHOLD 2000
//
// helpers
//

@@ -345,7 +347,9 @@ const char * llama_grammar_parser::parse_sequence(
    size_t last_sym_start = rule.size();
    const char * pos = src;

    auto handle_repetitions = [&](int min_times, int max_times) {
    // use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used
    // (though it's technically the same as -1 now)
    auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) {

        if (last_sym_start == rule.size()) {
            throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);

@@ -373,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence(
            rule.resize(last_sym_start);
        } else {
            // Repeat the previous elements (min_times - 1) times
            for (int i = 1; i < min_times; i++) {
            for (unsigned long i = 1; i < min_times; i++) {
                rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
            }
        }

        uint32_t last_rec_rule_id = 0;
        auto n_opt = max_times < 0 ? 1 : max_times - min_times;
        auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;

        llama_grammar_rule rec_rule(prev_rule);
        for (int i = 0; i < n_opt; i++) {
        for (unsigned long i = 0; i < n_opt; i++) {
            rec_rule.resize(prev_rule.size());
            uint32_t rec_rule_id = generate_symbol_id( rule_name);
            if (i > 0 || max_times < 0) {
                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
            if (i > 0 || max_times == UINT64_MAX) {
                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id});
            }
            rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
            rec_rule.push_back({LLAMA_GRETYPE_END, 0});

@@ -478,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence(
                throw std::runtime_error(std::string("expecting an int at ") + pos);
            }
            const char * int_end = parse_int(pos);
            int min_times = std::stoul(std::string(pos, int_end - pos));
            unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
            pos = parse_space(int_end, is_nested);

            int max_times = -1;
            unsigned long max_times = UINT64_MAX;

            if (*pos == '}') {
                max_times = min_times;

@@ -502,6 +506,9 @@ const char * llama_grammar_parser::parse_sequence(
            } else {
                throw std::runtime_error(std::string("expecting ',' at ") + pos);
            }
            if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) {
                throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
            }
            handle_repetitions(min_times, max_times);
        } else {
            break;

@@ -20,10 +20,10 @@ static llama_logger_state g_logger_state;
time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}

time_meas::~time_meas() {
    if (t_start_us >= 0) {
        t_acc += ggml_time_us() - t_start_us;
    }
    if (t_start_us >= 0) {
        t_acc += ggml_time_us() - t_start_us;
    }
}

void llama_log_set(ggml_log_callback log_callback, void * user_data) {
    ggml_log_set(log_callback, user_data);

@@ -472,9 +472,6 @@ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
    for (auto * smpl : chain->samplers) {
        llama_sampler_reset(smpl);
    }

    chain->t_sample_us = 0;
    chain->n_sample = 0;
}

static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {

@@ -2670,8 +2667,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
void llama_perf_sampler_print(const struct llama_sampler * chain) {
    const auto data = llama_perf_sampler(chain);

    LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
            __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
    LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
}

void llama_perf_sampler_reset(struct llama_sampler * chain) {

@@ -2681,5 +2677,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {

    auto * ctx = (struct llama_sampler_chain *) chain->ctx;

    ctx->t_sample_us = ctx->n_sample = 0;
    ctx->t_sample_us = 0;
    ctx->n_sample = 0;
}

@@ -2776,24 +2776,34 @@ struct test_cpy : public test_case {
struct test_cont : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    bool use_view_slice;

    std::string vars() override {
        return VARS_TO_STR2(type, ne);
        return VARS_TO_STR3(type, ne, use_view_slice);
    }

    test_cont(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 1})
        : type(type), ne(ne) {}
            std::array<int64_t, 4> ne = {10, 10, 10, 1},
            bool use_view_slice = false)
        : type(type), ne(ne), use_view_slice(use_view_slice) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_set_param(src);
        ggml_set_name(src, "src");

        src = ggml_transpose(ctx, src);
        ggml_set_name(src, "src_transposed");

        ggml_tensor * out = ggml_cont(ctx, src);
        ggml_tensor * dst;
        if (use_view_slice) {
            dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3],
                src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1));
            ggml_set_name(dst, "src_view_slice");
        } else {
            dst = ggml_transpose(ctx, src);
            ggml_set_name(dst, "src_transposed");
        }

        ggml_tensor * out = ggml_cont(ctx, dst);
        ggml_set_name(out, "out");

        return out;

@@ -6945,16 +6955,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));

    test_cases.emplace_back(new test_cont());
    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5}));
    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
    for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
        for (bool use_view_slice : { true, false }) {
            for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},
                    {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {
                if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) {
                    continue; // TODO: add after WebGPU is fixed
                }
                test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice));
            }
        }
    }

    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {

@@ -7015,6 +7026,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));

    test_cases.emplace_back(new test_add1());
    test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1}));
    test_cases.emplace_back(new test_scale());
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test

@@ -7354,9 +7366,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 }));
        test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 }));
        test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 }));
        test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 }));
    }

    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));

@@ -7501,13 +7517,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    }

    for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
        for (uint32_t i = 4; i <= 1024*1024; i *= 2) {
            test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}));
            test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1}));
        }
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));

@@ -7556,6 +7574,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
    test_cases.emplace_back(new test_roll());
    test_cases.emplace_back(new test_arange());
    test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f));
    test_cases.emplace_back(new test_timestep_embedding());
    test_cases.emplace_back(new test_leaky_relu());

@@ -7583,6 +7602,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_fill(0.0f));
    test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
    test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
    test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));

    test_cases.emplace_back(new test_solve_tri());
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));

@@ -147,11 +147,15 @@ int main(int argc, char ** argv) {
        return 1;
    }

    auto * mem = llama_get_memory(ctx);

    llama_memory_t mem = llama_get_memory(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // note: the time for chat template initialization is not negligible:
    auto chat_templates = common_chat_templates_init(model, params.chat_template);

    // start measuring performance timings from here
    llama_perf_context_reset(ctx);

    LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);

    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
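The hunk above only swaps auto * for the explicit llama_memory_t handle type. As a hedged sketch (not code from this commit), this is how such a handle is typically used afterwards, assuming the llama_memory_* helpers declared in llama.h; the helper names are an assumption about the current public header, not something this diff introduces.

// Hedged sketch: typical follow-up operations on a llama_memory_t handle.
#include "llama.h"

static void prune_and_clear(llama_context * ctx) {
    llama_memory_t mem = llama_get_memory(ctx);

    // drop everything in sequence 0 from position 10 to the end
    llama_memory_seq_rm(mem, /*seq_id=*/0, /*p0=*/10, /*p1=*/-1);

    // or wipe the memory entirely, including its data buffers
    llama_memory_clear(mem, /*data=*/true);
}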
Binary file not shown.

@@ -25,3 +25,4 @@ vite.config.ts.timestamp-*

*storybook.log
storybook-static
*.code-workspace

@@ -2109,9 +2109,9 @@
      }
    },
    "node_modules/@sveltejs/kit": {
      "version": "2.48.4",
      "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.4.tgz",
      "integrity": "sha512-TGFX1pZUt9qqY20Cv5NyYvy0iLWHf2jXi8s+eCGsig7jQMdwZWKUFMR6TbvFNhfDSUpc1sH/Y5EHv20g3HHA3g==",
      "version": "2.48.5",
      "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz",
      "integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==",
      "dev": true,
      "license": "MIT",
      "dependencies": {

@@ -5087,9 +5087,9 @@
      "license": "MIT"
    },
    "node_modules/js-yaml": {
      "version": "4.1.0",
      "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz",
      "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==",
      "version": "4.1.1",
      "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz",
      "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==",
      "dev": true,
      "license": "MIT",
      "dependencies": {
@ -0,0 +1,273 @@
|
|||
<script lang="ts">
|
||||
import { FileText, Image, Music, FileIcon, Eye } from '@lucide/svelte';
|
||||
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
|
||||
import { convertPDFToImage } from '$lib/utils/pdf-processing';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
|
||||
interface Props {
|
||||
// Either an uploaded file or a stored attachment
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
// For uploaded files
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
textContent?: string;
|
||||
}
|
||||
|
||||
let { uploadedFile, attachment, preview, name, type, textContent }: Props = $props();
|
||||
|
||||
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
|
||||
|
||||
let displayPreview = $derived(
|
||||
uploadedFile?.preview || (attachment?.type === 'imageFile' ? attachment.base64Url : preview)
|
||||
);
|
||||
|
||||
let displayType = $derived(
|
||||
uploadedFile?.type ||
|
||||
(attachment?.type === 'imageFile'
|
||||
? 'image'
|
||||
: attachment?.type === 'textFile'
|
||||
? 'text'
|
||||
: attachment?.type === 'audioFile'
|
||||
? attachment.mimeType || 'audio'
|
||||
: attachment?.type === 'pdfFile'
|
||||
? MimeTypeApplication.PDF
|
||||
: type || 'unknown')
|
||||
);
|
||||
|
||||
let displayTextContent = $derived(
|
||||
uploadedFile?.textContent ||
|
||||
(attachment?.type === 'textFile'
|
||||
? attachment.content
|
||||
: attachment?.type === 'pdfFile'
|
||||
? attachment.content
|
||||
: textContent)
|
||||
);
|
||||
|
||||
let isAudio = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.AUDIO || displayType === 'audio'
|
||||
);
|
||||
|
||||
let isImage = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.IMAGE || displayType === 'image'
|
||||
);
|
||||
|
||||
let isPdf = $derived(displayType === MimeTypeApplication.PDF);
|
||||
|
||||
let isText = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.TEXT || displayType === 'text'
|
||||
);
|
||||
|
||||
let IconComponent = $derived(() => {
|
||||
if (isImage) return Image;
|
||||
if (isText || isPdf) return FileText;
|
||||
if (isAudio) return Music;
|
||||
|
||||
return FileIcon;
|
||||
});
|
||||
|
||||
let pdfViewMode = $state<'text' | 'pages'>('pages');
|
||||
|
||||
let pdfImages = $state<string[]>([]);
|
||||
|
||||
let pdfImagesLoading = $state(false);
|
||||
|
||||
let pdfImagesError = $state<string | null>(null);
|
||||
|
||||
async function loadPdfImages() {
|
||||
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
|
||||
|
||||
pdfImagesLoading = true;
|
||||
pdfImagesError = null;
|
||||
|
||||
try {
|
||||
let file: File | null = null;
|
||||
|
||||
if (uploadedFile?.file) {
|
||||
file = uploadedFile.file;
|
||||
} else if (attachment?.type === 'pdfFile') {
|
||||
// Check if we have pre-processed images
|
||||
if (attachment.images && Array.isArray(attachment.images)) {
|
||||
pdfImages = attachment.images;
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert base64 back to File for processing
|
||||
if (attachment.base64Data) {
|
||||
const base64Data = attachment.base64Data;
|
||||
const byteCharacters = atob(base64Data);
|
||||
const byteNumbers = new Array(byteCharacters.length);
|
||||
for (let i = 0; i < byteCharacters.length; i++) {
|
||||
byteNumbers[i] = byteCharacters.charCodeAt(i);
|
||||
}
|
||||
const byteArray = new Uint8Array(byteNumbers);
|
||||
file = new File([byteArray], displayName, { type: MimeTypeApplication.PDF });
|
||||
}
|
||||
}
|
||||
|
||||
if (file) {
|
||||
pdfImages = await convertPDFToImage(file);
|
||||
} else {
|
||||
throw new Error('No PDF file available for conversion');
|
||||
}
|
||||
} catch (error) {
|
||||
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
|
||||
} finally {
|
||||
pdfImagesLoading = false;
|
||||
}
|
||||
}
|
||||
|
||||
export function reset() {
|
||||
pdfImages = [];
|
||||
pdfImagesLoading = false;
|
||||
pdfImagesError = null;
|
||||
pdfViewMode = 'pages';
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (isPdf && pdfViewMode === 'pages') {
|
||||
loadPdfImages();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="flex items-center justify-end gap-6">
|
||||
{#if isPdf}
|
||||
<div class="flex items-center gap-2">
|
||||
<Button
|
||||
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => (pdfViewMode = 'text')}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
<FileText class="mr-1 h-4 w-4" />
|
||||
|
||||
Text
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => {
|
||||
pdfViewMode = 'pages';
|
||||
loadPdfImages();
|
||||
}}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
{#if pdfImagesLoading}
|
||||
<div
|
||||
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
|
||||
></div>
|
||||
{:else}
|
||||
<Eye class="mr-1 h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
Pages
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex-1 overflow-auto">
|
||||
{#if isImage && displayPreview}
|
||||
<div class="flex items-center justify-center">
|
||||
<img
|
||||
src={displayPreview}
|
||||
alt={displayName}
|
||||
class="max-h-full rounded-lg object-contain shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{:else if isPdf && pdfViewMode === 'pages'}
|
||||
{#if pdfImagesLoading}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<div
|
||||
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
|
||||
></div>
|
||||
|
||||
<p class="text-muted-foreground">Converting PDF to images...</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImagesError}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
|
||||
|
||||
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
|
||||
|
||||
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImages.length > 0}
|
||||
<div class="max-h-[70vh] space-y-4 overflow-auto">
|
||||
{#each pdfImages as image, index (image)}
|
||||
<div class="text-center">
|
||||
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
|
||||
|
||||
<img
|
||||
src={image}
|
||||
alt="PDF Page {index + 1}"
|
||||
class="mx-auto max-w-full rounded-lg shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
|
||||
<div
|
||||
class="max-h-[60vh] overflow-auto rounded-lg bg-muted p-4 font-mono text-sm break-words whitespace-pre-wrap"
|
||||
>
|
||||
{displayTextContent}
|
||||
</div>
|
||||
{:else if isAudio}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="w-full max-w-md text-center">
|
||||
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
{#if attachment?.type === 'audioFile'}
|
||||
<audio
|
||||
controls
|
||||
class="mb-4 w-full"
|
||||
src="data:{attachment.mimeType};base64,{attachment.base64Data}"
|
||||
>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else if uploadedFile?.preview}
|
||||
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else}
|
||||
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
|
||||
{/if}
|
||||
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{displayName}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
{/if}
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -1,314 +0,0 @@
|
|||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { FileText, Image, Music, FileIcon, Eye } from '@lucide/svelte';
|
||||
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
|
||||
import { convertPDFToImage } from '$lib/utils/pdf-processing';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import { formatFileSize } from '$lib/utils/file-preview';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
// Either an uploaded file or a stored attachment
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
// For uploaded files
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(),
|
||||
uploadedFile,
|
||||
attachment,
|
||||
preview,
|
||||
name,
|
||||
type,
|
||||
size,
|
||||
textContent
|
||||
}: Props = $props();
|
||||
|
||||
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
|
||||
|
||||
let displayPreview = $derived(
|
||||
uploadedFile?.preview || (attachment?.type === 'imageFile' ? attachment.base64Url : preview)
|
||||
);
|
||||
|
||||
let displayType = $derived(
|
||||
uploadedFile?.type ||
|
||||
(attachment?.type === 'imageFile'
|
||||
? 'image'
|
||||
: attachment?.type === 'textFile'
|
||||
? 'text'
|
||||
: attachment?.type === 'audioFile'
|
||||
? attachment.mimeType || 'audio'
|
||||
: attachment?.type === 'pdfFile'
|
||||
? MimeTypeApplication.PDF
|
||||
: type || 'unknown')
|
||||
);
|
||||
|
||||
let displaySize = $derived(uploadedFile?.size || size);
|
||||
|
||||
let displayTextContent = $derived(
|
||||
uploadedFile?.textContent ||
|
||||
(attachment?.type === 'textFile'
|
||||
? attachment.content
|
||||
: attachment?.type === 'pdfFile'
|
||||
? attachment.content
|
||||
: textContent)
|
||||
);
|
||||
|
||||
let isAudio = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.AUDIO || displayType === 'audio'
|
||||
);
|
||||
|
||||
let isImage = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.IMAGE || displayType === 'image'
|
||||
);
|
||||
|
||||
let isPdf = $derived(displayType === MimeTypeApplication.PDF);
|
||||
|
||||
let isText = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.TEXT || displayType === 'text'
|
||||
);
|
||||
|
||||
let IconComponent = $derived(() => {
|
||||
if (isImage) return Image;
|
||||
if (isText || isPdf) return FileText;
|
||||
if (isAudio) return Music;
|
||||
|
||||
return FileIcon;
|
||||
});
|
||||
|
||||
let pdfViewMode = $state<'text' | 'pages'>('pages');
|
||||
|
||||
let pdfImages = $state<string[]>([]);
|
||||
|
||||
let pdfImagesLoading = $state(false);
|
||||
|
||||
let pdfImagesError = $state<string | null>(null);
|
||||
|
||||
async function loadPdfImages() {
|
||||
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
|
||||
|
||||
pdfImagesLoading = true;
|
||||
pdfImagesError = null;
|
||||
|
||||
try {
|
||||
let file: File | null = null;
|
||||
|
||||
if (uploadedFile?.file) {
|
||||
file = uploadedFile.file;
|
||||
} else if (attachment?.type === 'pdfFile') {
|
||||
// Check if we have pre-processed images
|
||||
if (attachment.images && Array.isArray(attachment.images)) {
|
||||
pdfImages = attachment.images;
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert base64 back to File for processing
|
||||
if (attachment.base64Data) {
|
||||
const base64Data = attachment.base64Data;
|
||||
const byteCharacters = atob(base64Data);
|
||||
const byteNumbers = new Array(byteCharacters.length);
|
||||
for (let i = 0; i < byteCharacters.length; i++) {
|
||||
byteNumbers[i] = byteCharacters.charCodeAt(i);
|
||||
}
|
||||
const byteArray = new Uint8Array(byteNumbers);
|
||||
file = new File([byteArray], displayName, { type: MimeTypeApplication.PDF });
|
||||
}
|
||||
}
|
||||
|
||||
if (file) {
|
||||
pdfImages = await convertPDFToImage(file);
|
||||
} else {
|
||||
throw new Error('No PDF file available for conversion');
|
||||
}
|
||||
} catch (error) {
|
||||
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
|
||||
} finally {
|
||||
pdfImagesLoading = false;
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (open) {
|
||||
pdfImages = [];
|
||||
pdfImagesLoading = false;
|
||||
pdfImagesError = null;
|
||||
pdfViewMode = 'pages';
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (open && isPdf && pdfViewMode === 'pages') {
|
||||
loadPdfImages();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden !p-10 sm:w-auto sm:max-w-6xl">
|
||||
<Dialog.Header class="flex-shrink-0">
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<div class="flex items-center gap-3">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="h-5 w-5 text-muted-foreground" />
|
||||
{/if}
|
||||
|
||||
<div>
|
||||
<Dialog.Title class="text-left">{displayName}</Dialog.Title>
|
||||
|
||||
<div class="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<span>{displayType}</span>
|
||||
|
||||
{#if displaySize}
|
||||
<span>•</span>
|
||||
|
||||
<span>{formatFileSize(displaySize)}</span>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isPdf}
|
||||
<div class="flex items-center gap-2">
|
||||
<Button
|
||||
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => (pdfViewMode = 'text')}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
<FileText class="mr-1 h-4 w-4" />
|
||||
|
||||
Text
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => {
|
||||
pdfViewMode = 'pages';
|
||||
loadPdfImages();
|
||||
}}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
{#if pdfImagesLoading}
|
||||
<div
|
||||
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
|
||||
></div>
|
||||
{:else}
|
||||
<Eye class="mr-1 h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
Pages
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="flex-1 overflow-auto">
|
||||
{#if isImage && displayPreview}
|
||||
<div class="flex items-center justify-center">
|
||||
<img
|
||||
src={displayPreview}
|
||||
alt={displayName}
|
||||
class="max-h-full rounded-lg object-contain shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{:else if isPdf && pdfViewMode === 'pages'}
|
||||
{#if pdfImagesLoading}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<div
|
||||
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
|
||||
></div>
|
||||
|
||||
<p class="text-muted-foreground">Converting PDF to images...</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImagesError}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
|
||||
|
||||
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
|
||||
|
||||
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImages.length > 0}
|
||||
<div class="max-h-[70vh] space-y-4 overflow-auto">
|
||||
{#each pdfImages as image, index (image)}
|
||||
<div class="text-center">
|
||||
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
|
||||
|
||||
<img
|
||||
src={image}
|
||||
alt="PDF Page {index + 1}"
|
||||
class="mx-auto max-w-full rounded-lg shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
|
||||
<div
|
||||
class="max-h-[60vh] overflow-auto rounded-lg bg-muted p-4 font-mono text-sm break-words whitespace-pre-wrap"
|
||||
>
|
||||
{displayTextContent}
|
||||
</div>
|
||||
{:else if isAudio}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="w-full max-w-md text-center">
|
||||
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
{#if attachment?.type === 'audioFile'}
|
||||
<audio
|
||||
controls
|
||||
class="mb-4 w-full"
|
||||
src="data:{attachment.mimeType};base64,{attachment.base64Data}"
|
||||
>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else if uploadedFile?.preview}
|
||||
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else}
|
||||
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
|
||||
{/if}
|
||||
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{displayName}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
{/if}
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
|
|
@ -1,11 +1,10 @@
|
|||
<script lang="ts">
|
||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
||||
import { ChatAttachmentThumbnailImage, ChatAttachmentThumbnailFile } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
||||
import ChatAttachmentsViewAllDialog from './ChatAttachmentsViewAllDialog.svelte';
|
||||
import { DialogChatAttachmentPreview, DialogChatAttachmentsViewAll } from '$lib/components/app';
|
||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||
|
||||
interface Props {
|
||||
|
|
@ -200,7 +199,7 @@
|
|||
>
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isImage && item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
<ChatAttachmentThumbnailImage
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
|
|
@ -213,7 +212,7 @@
|
|||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentFilePreview
|
||||
<ChatAttachmentThumbnailFile
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
|
|
@ -256,7 +255,7 @@
|
|||
{/if}
|
||||
|
||||
{#if previewItem}
|
||||
<ChatAttachmentPreviewDialog
|
||||
<DialogChatAttachmentPreview
|
||||
bind:open={previewDialogOpen}
|
||||
uploadedFile={previewItem.uploadedFile}
|
||||
attachment={previewItem.attachment}
|
||||
|
|
@ -268,7 +267,7 @@
|
|||
/>
|
||||
{/if}
|
||||
|
||||
<ChatAttachmentsViewAllDialog
|
||||
<DialogChatAttachmentsViewAll
|
||||
bind:open={viewAllDialogOpen}
|
||||
{uploadedFiles}
|
||||
{attachments}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
||||
import {
|
||||
ChatAttachmentThumbnailImage,
|
||||
ChatAttachmentThumbnailFile,
|
||||
DialogChatAttachmentPreview
|
||||
} from '$lib/components/app';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||
|
||||
interface Props {
|
||||
open?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
|
|
@ -18,7 +19,6 @@
|
|||
}
|
||||
|
||||
let {
|
||||
open = $bindable(false),
|
||||
uploadedFiles = [],
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
|
|
@ -127,70 +127,57 @@
|
|||
}
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Portal>
|
||||
<Dialog.Overlay />
|
||||
|
||||
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>All Attachments ({displayItems.length})</Dialog.Title>
|
||||
<Dialog.Description class="text-sm text-muted-foreground">
|
||||
View and manage all attached files
|
||||
</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
|
||||
{#if fileItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each fileItems as item (item.id)}
|
||||
<ChatAttachmentFilePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if imageItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each imageItems as item (item.id)}
|
||||
{#if item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
<div class="space-y-4">
|
||||
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
|
||||
{#if fileItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each fileItems as item (item.id)}
|
||||
<ChatAttachmentThumbnailFile
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
</Dialog.Root>
|
||||
{/if}
|
||||
|
||||
{#if imageItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each imageItems as item (item.id)}
|
||||
{#if item.preview}
|
||||
<ChatAttachmentThumbnailImage
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if previewItem}
|
||||
<ChatAttachmentPreviewDialog
|
||||
<DialogChatAttachmentPreview
|
||||
bind:open={previewDialogOpen}
|
||||
uploadedFile={previewItem.uploadedFile}
|
||||
attachment={previewItem.attachment}
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
<script lang="ts">
|
||||
import { Square, ArrowUp } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import ChatFormActionFileAttachments from './ChatFormActionFileAttachments.svelte';
|
||||
import ChatFormActionRecord from './ChatFormActionRecord.svelte';
|
||||
import ChatFormModelSelector from './ChatFormModelSelector.svelte';
|
||||
import {
|
||||
ChatFormActionFileAttachments,
|
||||
ChatFormActionRecord,
|
||||
ChatFormModelSelector
|
||||
} from '$lib/components/app';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import type { FileTypeCategory } from '$lib/enums/files';
|
||||
|
||||
|
|
@ -10,6 +10,7 @@
|
|||
class?: string;
|
||||
message: DatabaseMessage;
|
||||
onCopy?: (message: DatabaseMessage) => void;
|
||||
onContinueAssistantMessage?: (message: DatabaseMessage) => void;
|
||||
onDelete?: (message: DatabaseMessage) => void;
|
||||
onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void;
|
||||
onEditWithReplacement?: (
|
||||
|
|
@ -17,6 +18,7 @@
|
|||
newContent: string,
|
||||
shouldBranch: boolean
|
||||
) => void;
|
||||
onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
onRegenerateWithBranching?: (message: DatabaseMessage) => void;
|
||||
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||
|
|
@ -26,9 +28,11 @@
|
|||
class: className = '',
|
||||
message,
|
||||
onCopy,
|
||||
onContinueAssistantMessage,
|
||||
onDelete,
|
||||
onEditWithBranching,
|
||||
onEditWithReplacement,
|
||||
onEditUserMessagePreserveResponses,
|
||||
onNavigateToSibling,
|
||||
onRegenerateWithBranching,
|
||||
siblingInfo = null
|
||||
|
|
@ -133,17 +137,33 @@
|
|||
onRegenerateWithBranching?.(message);
|
||||
}
|
||||
|
||||
function handleContinue() {
|
||||
onContinueAssistantMessage?.(message);
|
||||
}
|
||||
|
||||
function handleSaveEdit() {
|
||||
if (message.role === 'user') {
|
||||
// For user messages, trim to avoid accidental whitespace
|
||||
onEditWithBranching?.(message, editedContent.trim());
|
||||
} else {
|
||||
onEditWithReplacement?.(message, editedContent.trim(), shouldBranchAfterEdit);
|
||||
// For assistant messages, preserve exact content including trailing whitespace
|
||||
// This is important for the Continue feature to work properly
|
||||
onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit);
|
||||
}
|
||||
|
||||
isEditing = false;
|
||||
shouldBranchAfterEdit = false;
|
||||
}
|
||||
|
||||
function handleSaveEditOnly() {
|
||||
if (message.role === 'user') {
|
||||
// For user messages, trim to avoid accidental whitespace
|
||||
onEditUserMessagePreserveResponses?.(message, editedContent.trim());
|
||||
}
|
||||
|
||||
isEditing = false;
|
||||
}
|
||||
|
||||
function handleShowDeleteDialogChange(show: boolean) {
|
||||
showDeleteDialog = show;
|
||||
}
|
||||
|
|
@ -166,6 +186,7 @@
|
|||
onEditedContentChange={handleEditedContentChange}
|
||||
{onNavigateToSibling}
|
||||
onSaveEdit={handleSaveEdit}
|
||||
onSaveEditOnly={handleSaveEditOnly}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
|
|
@ -181,6 +202,7 @@
|
|||
messageContent={message.content}
|
||||
onCancelEdit={handleCancelEdit}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onContinue={handleContinue}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
<script lang="ts">
|
||||
import { Edit, Copy, RefreshCw, Trash2 } from '@lucide/svelte';
|
||||
import { ActionButton, ConfirmationDialog } from '$lib/components/app';
|
||||
import ChatMessageBranchingControls from './ChatMessageBranchingControls.svelte';
|
||||
import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
|
||||
import {
|
||||
ActionButton,
|
||||
ChatMessageBranchingControls,
|
||||
DialogConfirmation
|
||||
} from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
role: 'user' | 'assistant';
|
||||
|
|
@ -18,6 +21,7 @@
|
|||
onCopy: () => void;
|
||||
onEdit?: () => void;
|
||||
onRegenerate?: () => void;
|
||||
onContinue?: () => void;
|
||||
onDelete: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
|
|
@ -31,6 +35,7 @@
|
|||
onCopy,
|
||||
onEdit,
|
||||
onConfirmDelete,
|
||||
onContinue,
|
||||
onDelete,
|
||||
onNavigateToSibling,
|
||||
onShowDeleteDialogChange,
|
||||
|
|
@ -69,12 +74,16 @@
|
|||
<ActionButton icon={RefreshCw} tooltip="Regenerate" onclick={onRegenerate} />
|
||||
{/if}
|
||||
|
||||
{#if role === 'assistant' && onContinue}
|
||||
<ActionButton icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
|
||||
{/if}
|
||||
|
||||
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ConfirmationDialog
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Message"
|
||||
description={deletionInfo && deletionInfo.totalCount > 1
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import { ChatMessageThinkingBlock, MarkdownContent } from '$lib/components/app';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading } from '$lib/stores/chat.svelte';
|
||||
import autoResizeTextarea from '$lib/utils/autoresize-textarea';
|
||||
import { fade } from 'svelte/transition';
|
||||
import {
|
||||
Check,
|
||||
|
|
@ -39,6 +40,7 @@
|
|||
onCancelEdit?: () => void;
|
||||
onCopy: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
onContinue?: () => void;
|
||||
onDelete: () => void;
|
||||
onEdit?: () => void;
|
||||
onEditKeydown?: (event: KeyboardEvent) => void;
|
||||
|
|
@ -65,6 +67,7 @@
|
|||
messageContent,
|
||||
onCancelEdit,
|
||||
onConfirmDelete,
|
||||
onContinue,
|
||||
onCopy,
|
||||
onDelete,
|
||||
onEdit,
|
||||
|
|
@ -107,6 +110,12 @@
|
|||
void copyToClipboard(model ?? '');
|
||||
}
|
||||
|
||||
	$effect(() => {
		if (isEditing && textareaElement) {
			autoResizeTextarea(textareaElement);
		}
	});

	function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
		const callNumber = index + 1;
		const functionName = toolCall.function?.name?.trim();

@@ -190,7 +199,10 @@
	bind:value={editedContent}
	class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
	onkeydown={onEditKeydown}
	oninput={(e) => onEditedContentChange?.(e.currentTarget.value)}
	oninput={(e) => {
		autoResizeTextarea(e.currentTarget);
		onEditedContentChange?.(e.currentTarget.value);
	}}
	placeholder="Edit assistant message..."
></textarea>

@@ -335,6 +347,9 @@
	{onCopy}
	{onEdit}
	{onRegenerate}
	onContinue={currentConfig.enableContinueGeneration && !thinkingContent
		? onContinue
		: undefined}
	{onDelete}
	{onConfirmDelete}
	{onNavigateToSibling}

@@ -1,10 +1,11 @@
<script lang="ts">
	import { Check, X } from '@lucide/svelte';
	import { Check, X, Send } from '@lucide/svelte';
	import { Card } from '$lib/components/ui/card';
	import { Button } from '$lib/components/ui/button';
	import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
	import { INPUT_CLASSES } from '$lib/constants/input-classes';
	import { config } from '$lib/stores/settings.svelte';
	import autoResizeTextarea from '$lib/utils/autoresize-textarea';
	import ChatMessageActions from './ChatMessageActions.svelte';

	interface Props {

@@ -22,6 +23,7 @@
	} | null;
	onCancelEdit: () => void;
	onSaveEdit: () => void;
	onSaveEditOnly?: () => void;
	onEditKeydown: (event: KeyboardEvent) => void;
	onEditedContentChange: (content: string) => void;
	onCopy: () => void;

@@ -43,6 +45,7 @@
	deletionInfo,
	onCancelEdit,
	onSaveEdit,
	onSaveEditOnly,
	onEditKeydown,
	onEditedContentChange,
	onCopy,

@@ -58,6 +61,12 @@
	let messageElement: HTMLElement | undefined = $state();
	const currentConfig = config();

	$effect(() => {
		if (isEditing && textareaElement) {
			autoResizeTextarea(textareaElement);
		}
	});

	$effect(() => {
		if (!messageElement || !message.content.trim()) return;

@@ -95,20 +104,34 @@
	bind:value={editedContent}
	class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
	onkeydown={onEditKeydown}
	oninput={(e) => onEditedContentChange(e.currentTarget.value)}
	oninput={(e) => {
		autoResizeTextarea(e.currentTarget);
		onEditedContentChange(e.currentTarget.value);
	}}
	placeholder="Edit your message..."
></textarea>

<div class="mt-2 flex justify-end gap-2">
	<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
	<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="ghost">
		<X class="mr-1 h-3 w-3" />

		Cancel
	</Button>

	<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
		<Check class="mr-1 h-3 w-3" />
	{#if onSaveEditOnly}
		<Button
			class="h-8 px-3"
			onclick={onSaveEditOnly}
			disabled={!editedContent.trim()}
			size="sm"
			variant="outline"
		>
			<Check class="mr-1 h-3 w-3" />
			Save
		</Button>
	{/if}

	<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
		<Send class="mr-1 h-3 w-3" />
		Send
	</Button>
</div>

@@ -3,10 +3,12 @@
	import { DatabaseStore } from '$lib/stores/database';
	import {
		activeConversation,
		continueAssistantMessage,
		deleteMessage,
		navigateToSibling,
		editMessageWithBranching,
		editAssistantMessage,
		editMessageWithBranching,
		editUserMessagePreserveResponses,
		navigateToSibling,
		regenerateMessageWithBranching
	} from '$lib/stores/chat.svelte';
	import { getMessageSiblings } from '$lib/utils/branching';

@@ -93,6 +95,26 @@

		refreshAllMessages();
	}

	async function handleContinueAssistantMessage(message: DatabaseMessage) {
		onUserAction?.();

		await continueAssistantMessage(message.id);

		refreshAllMessages();
	}

	async function handleEditUserMessagePreserveResponses(
		message: DatabaseMessage,
		newContent: string
	) {
		onUserAction?.();

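		// In-place edit: only the message content is updated; the assistant responses below are preserved and nothing is regenerated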
		await editUserMessagePreserveResponses(message.id, newContent);

		refreshAllMessages();
	}

	async function handleDeleteMessage(message: DatabaseMessage) {
		await deleteMessage(message.id);

@@ -110,7 +132,9 @@
	onNavigateToSibling={handleNavigateToSibling}
	onEditWithBranching={handleEditWithBranching}
	onEditWithReplacement={handleEditWithReplacement}
	onEditUserMessagePreserveResponses={handleEditUserMessagePreserveResponses}
	onRegenerateWithBranching={handleRegenerateWithBranching}
	onContinueAssistantMessage={handleContinueAssistantMessage}
/>
{/each}
</div>

@@ -5,13 +5,13 @@
ChatScreenHeader,
|
||||
ChatScreenWarning,
|
||||
ChatMessages,
|
||||
ChatProcessingInfo,
|
||||
EmptyFileAlertDialog,
|
||||
ChatErrorDialog,
|
||||
ChatScreenProcessingInfo,
|
||||
DialogEmptyFileAlert,
|
||||
DialogChatError,
|
||||
ServerErrorSplash,
|
||||
ServerInfo,
|
||||
ServerLoadingSplash,
|
||||
ConfirmationDialog
|
||||
DialogConfirmation
|
||||
} from '$lib/components/app';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import {
|
||||
|
|
@@ -299,7 +299,7 @@
|
|||
class="pointer-events-none sticky right-0 bottom-0 left-0 mt-auto"
|
||||
in:slide={{ duration: 150, axis: 'y' }}
|
||||
>
|
||||
<ChatProcessingInfo />
|
||||
<ChatScreenProcessingInfo />
|
||||
|
||||
{#if serverWarning()}
|
||||
<ChatScreenWarning class="pointer-events-auto mx-auto max-w-[48rem] px-4" />
|
||||
|
|
@@ -432,7 +432,7 @@
|
|||
</AlertDialog.Portal>
|
||||
</AlertDialog.Root>
|
||||
|
||||
<ConfirmationDialog
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
|
||||
|
|
@@ -444,7 +444,7 @@
|
|||
onCancel={() => (showDeleteDialog = false)}
|
||||
/>
|
||||
|
||||
<EmptyFileAlertDialog
|
||||
<DialogEmptyFileAlert
|
||||
bind:open={showEmptyFileDialog}
|
||||
emptyFiles={emptyFileNames}
|
||||
onOpenChange={(open) => {
|
||||
|
|
@@ -454,7 +454,7 @@
|
|||
}}
|
||||
/>
|
||||
|
||||
<ChatErrorDialog
|
||||
<DialogChatError
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,6 @@
<script lang="ts">
	import { Settings } from '@lucide/svelte';
	import { ChatSettingsDialog } from '$lib/components/app';
	import { DialogChatSettings } from '$lib/components/app';
	import { Button } from '$lib/components/ui/button';

	let settingsOpen = $state(false);

@@ -20,4 +20,4 @@
	</div>
</header>

<ChatSettingsDialog open={settingsOpen} onOpenChange={(open) => (settingsOpen = open)} />
<DialogChatSettings open={settingsOpen} onOpenChange={(open) => (settingsOpen = open)} />

@@ -12,20 +12,21 @@
ChevronRight,
|
||||
Database
|
||||
} from '@lucide/svelte';
|
||||
import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app';
|
||||
import ImportExportTab from './ImportExportTab.svelte';
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import {
|
||||
ChatSettingsFooter,
|
||||
ChatSettingsImportExportTab,
|
||||
ChatSettingsFields
|
||||
} from '$lib/components/app';
|
||||
import { ScrollArea } from '$lib/components/ui/scroll-area';
|
||||
import { config, updateMultipleConfig } from '$lib/stores/settings.svelte';
|
||||
import { setMode } from 'mode-watcher';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
open?: boolean;
|
||||
onSave?: () => void;
|
||||
}
|
||||
|
||||
let { onOpenChange, open = false }: Props = $props();
|
||||
let { onSave }: Props = $props();
|
||||
|
||||
const settingSections: Array<{
|
||||
fields: SettingsFieldConfig[];
|
||||
|
|
@@ -52,6 +53,11 @@
|
|||
{ value: 'dark', label: 'Dark', icon: Moon }
|
||||
]
|
||||
},
|
||||
{
|
||||
key: 'pasteLongTextToFileLen',
|
||||
label: 'Paste long text to file length',
|
||||
type: 'input'
|
||||
},
|
||||
{
|
||||
key: 'showMessageStats',
|
||||
label: 'Show message generation statistics',
|
||||
|
|
@@ -68,14 +74,15 @@
|
|||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'askForTitleConfirmation',
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
key: 'showModelInfo',
|
||||
label: 'Show model information',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'pasteLongTextToFileLen',
|
||||
label: 'Paste long text to file length',
|
||||
type: 'input'
|
||||
key: 'enableContinueGeneration',
|
||||
label: 'Enable "Continue" button',
|
||||
type: 'checkbox',
|
||||
isExperimental: true
|
||||
},
|
||||
{
|
||||
key: 'pdfAsImage',
|
||||
|
|
@@ -83,13 +90,13 @@
|
|||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showModelInfo',
|
||||
label: 'Show model information',
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
label: 'Render user content as Markdown',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
label: 'Render user content as Markdown',
|
||||
key: 'askForTitleConfirmation',
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
|
|
@@ -263,7 +270,6 @@
|
|||
settingSections.find((section) => section.title === activeSection) || settingSections[0]
|
||||
);
|
||||
let localConfig: SettingsConfigType = $state({ ...config() });
|
||||
let originalTheme: string = $state('');
|
||||
|
||||
let canScrollLeft = $state(false);
|
||||
let canScrollRight = $state(false);
|
||||
|
|
@@ -279,18 +285,10 @@
|
|||
localConfig[key] = value;
|
||||
}
|
||||
|
||||
function handleClose() {
|
||||
if (localConfig.theme !== originalTheme) {
|
||||
setMode(originalTheme as 'light' | 'dark' | 'system');
|
||||
}
|
||||
onOpenChange?.(false);
|
||||
}
|
||||
|
||||
function handleReset() {
|
||||
localConfig = { ...config() };
|
||||
|
||||
setMode(localConfig.theme as 'light' | 'dark' | 'system');
|
||||
originalTheme = localConfig.theme as string;
|
||||
}
|
||||
|
||||
function handleSave() {
|
||||
|
|
@@ -341,7 +339,7 @@
|
|||
}
|
||||
|
||||
updateMultipleConfig(processedConfig);
|
||||
onOpenChange?.(false);
|
||||
onSave?.();
|
||||
}
|
||||
|
||||
function scrollToCenter(element: HTMLElement) {
|
||||
|
|
@@ -377,14 +375,11 @@
|
|||
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1; // -1 for rounding
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (open) {
|
||||
localConfig = { ...config() };
|
||||
originalTheme = config().theme as string;
|
||||
export function reset() {
|
||||
localConfig = { ...config() };
|
||||
|
||||
setTimeout(updateScrollButtons, 100);
|
||||
}
|
||||
});
|
||||
setTimeout(updateScrollButtons, 100);
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (scrollContainer) {
|
||||
|
|
@@ -393,120 +388,106 @@
|
|||
});
|
||||
</script>
|
||||
|
||||
<Dialog.Root {open} onOpenChange={handleClose}>
|
||||
<Dialog.Content
|
||||
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
|
||||
md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
|
||||
style="max-width: 48rem;"
|
||||
>
|
||||
<div class="flex flex-1 flex-col overflow-hidden md:flex-row">
|
||||
<!-- Desktop Sidebar -->
|
||||
<div class="hidden w-64 border-r border-border/30 p-6 md:block">
|
||||
<nav class="space-y-1 py-2">
|
||||
<Dialog.Title class="mb-6 flex items-center gap-2">Settings</Dialog.Title>
|
||||
<div class="flex h-full flex-col overflow-hidden md:flex-row">
|
||||
<!-- Desktop Sidebar -->
|
||||
<div class="hidden w-64 border-r border-border/30 p-6 md:block">
|
||||
<nav class="space-y-1 py-2">
|
||||
{#each settingSections as section (section.title)}
|
||||
<button
|
||||
class="flex w-full cursor-pointer items-center gap-3 rounded-lg px-3 py-2 text-left text-sm transition-colors hover:bg-accent {activeSection ===
|
||||
section.title
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: 'text-muted-foreground'}"
|
||||
onclick={() => (activeSection = section.title)}
|
||||
>
|
||||
<section.icon class="h-4 w-4" />
|
||||
|
||||
{#each settingSections as section (section.title)}
|
||||
<button
|
||||
class="flex w-full cursor-pointer items-center gap-3 rounded-lg px-3 py-2 text-left text-sm transition-colors hover:bg-accent {activeSection ===
|
||||
section.title
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: 'text-muted-foreground'}"
|
||||
onclick={() => (activeSection = section.title)}
|
||||
>
|
||||
<section.icon class="h-4 w-4" />
|
||||
<span class="ml-2">{section.title}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</nav>
|
||||
</div>
|
||||
|
||||
<span class="ml-2">{section.title}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</nav>
|
||||
</div>
|
||||
<!-- Mobile Header with Horizontal Scrollable Menu -->
|
||||
<div class="flex flex-col md:hidden">
|
||||
<div class="border-b border-border/30 py-4">
|
||||
<!-- Horizontal Scrollable Category Menu with Navigation -->
|
||||
<div class="relative flex items-center" style="scroll-padding: 1rem;">
|
||||
<button
|
||||
class="absolute left-2 z-10 flex h-6 w-6 items-center justify-center rounded-full bg-muted shadow-md backdrop-blur-sm transition-opacity hover:bg-accent {canScrollLeft
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollLeft}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<ChevronLeft class="h-4 w-4" />
|
||||
</button>
|
||||
|
||||
<!-- Mobile Header with Horizontal Scrollable Menu -->
|
||||
<div class="flex flex-col md:hidden">
|
||||
<div class="border-b border-border/30 py-4">
|
||||
<Dialog.Title class="mb-6 flex items-center gap-2 px-4">Settings</Dialog.Title>
|
||||
|
||||
<!-- Horizontal Scrollable Category Menu with Navigation -->
|
||||
<div class="relative flex items-center" style="scroll-padding: 1rem;">
|
||||
<button
|
||||
class="absolute left-2 z-10 flex h-6 w-6 items-center justify-center rounded-full bg-muted shadow-md backdrop-blur-sm transition-opacity hover:bg-accent {canScrollLeft
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollLeft}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<ChevronLeft class="h-4 w-4" />
|
||||
</button>
|
||||
|
||||
<div
|
||||
class="scrollbar-hide overflow-x-auto py-2"
|
||||
bind:this={scrollContainer}
|
||||
onscroll={updateScrollButtons}
|
||||
>
|
||||
<div class="flex min-w-max gap-2">
|
||||
{#each settingSections as section (section.title)}
|
||||
<button
|
||||
class="flex cursor-pointer items-center gap-2 rounded-lg px-3 py-2 text-sm whitespace-nowrap transition-colors first:ml-4 last:mr-4 hover:bg-accent {activeSection ===
|
||||
section.title
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: 'text-muted-foreground'}"
|
||||
onclick={(e: MouseEvent) => {
|
||||
activeSection = section.title;
|
||||
scrollToCenter(e.currentTarget as HTMLElement);
|
||||
}}
|
||||
>
|
||||
<section.icon class="h-4 w-4 flex-shrink-0" />
|
||||
<span>{section.title}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="absolute right-2 z-10 flex h-6 w-6 items-center justify-center rounded-full bg-muted shadow-md backdrop-blur-sm transition-opacity hover:bg-accent {canScrollRight
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollRight}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<ChevronRight class="h-4 w-4" />
|
||||
</button>
|
||||
<div
|
||||
class="scrollbar-hide overflow-x-auto py-2"
|
||||
bind:this={scrollContainer}
|
||||
onscroll={updateScrollButtons}
|
||||
>
|
||||
<div class="flex min-w-max gap-2">
|
||||
{#each settingSections as section (section.title)}
|
||||
<button
|
||||
class="flex cursor-pointer items-center gap-2 rounded-lg px-3 py-2 text-sm whitespace-nowrap transition-colors first:ml-4 last:mr-4 hover:bg-accent {activeSection ===
|
||||
section.title
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: 'text-muted-foreground'}"
|
||||
onclick={(e: MouseEvent) => {
|
||||
activeSection = section.title;
|
||||
scrollToCenter(e.currentTarget as HTMLElement);
|
||||
}}
|
||||
>
|
||||
<section.icon class="h-4 w-4 flex-shrink-0" />
|
||||
<span>{section.title}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="absolute right-2 z-10 flex h-6 w-6 items-center justify-center rounded-full bg-muted shadow-md backdrop-blur-sm transition-opacity hover:bg-accent {canScrollRight
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollRight}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<ChevronRight class="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
|
||||
<div class="space-y-6 p-4 md:p-6">
|
||||
<div class="grid">
|
||||
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
|
||||
<currentSection.icon class="h-5 w-5" />
|
||||
|
||||
<h3 class="text-lg font-semibold">{currentSection.title}</h3>
|
||||
</div>
|
||||
|
||||
{#if currentSection.title === 'Import/Export'}
|
||||
<ImportExportTab />
|
||||
{:else}
|
||||
<div class="space-y-6">
|
||||
<ChatSettingsFields
|
||||
fields={currentSection.fields}
|
||||
{localConfig}
|
||||
onConfigChange={handleConfigChange}
|
||||
onThemeChange={handleThemeChange}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="mt-8 border-t pt-6">
|
||||
<p class="text-xs text-muted-foreground">
|
||||
Settings are saved in the browser's localStorage
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ChatSettingsFooter onReset={handleReset} onSave={handleSave} />
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
<ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
|
||||
<div class="space-y-6 p-4 md:p-6">
|
||||
<div class="grid">
|
||||
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
|
||||
<currentSection.icon class="h-5 w-5" />
|
||||
|
||||
<h3 class="text-lg font-semibold">{currentSection.title}</h3>
|
||||
</div>
|
||||
|
||||
{#if currentSection.title === 'Import/Export'}
|
||||
<ChatSettingsImportExportTab />
|
||||
{:else}
|
||||
<div class="space-y-6">
|
||||
<ChatSettingsFields
|
||||
fields={currentSection.fields}
|
||||
{localConfig}
|
||||
onConfigChange={handleConfigChange}
|
||||
onThemeChange={handleThemeChange}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="mt-8 border-t pt-6">
|
||||
<p class="text-xs text-muted-foreground">Settings are saved in the browser's localStorage</p>
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
<ChatSettingsFooter onReset={handleReset} onSave={handleSave} />
|
||||
|
|
@@ -1,5 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { RotateCcw } from '@lucide/svelte';
|
||||
import { RotateCcw, FlaskConical } from '@lucide/svelte';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
|
|
@@ -9,7 +9,7 @@
|
|||
import { supportsVision } from '$lib/stores/server.svelte';
|
||||
import { getParameterInfo, resetParameterToServerDefault } from '$lib/stores/settings.svelte';
|
||||
import { ParameterSyncService } from '$lib/services/parameter-sync';
|
||||
import ParameterSourceIndicator from './ParameterSourceIndicator.svelte';
|
||||
import { ChatSettingsParameterSourceIndicator } from '$lib/components/app';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
|
|
@@ -55,11 +55,15 @@
|
|||
})()}
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<Label for={field.key} class="text-sm font-medium">
|
||||
<Label for={field.key} class="flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
{#if isCustomRealTime}
|
||||
<ParameterSourceIndicator />
|
||||
<ChatSettingsParameterSourceIndicator />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
|
@@ -97,8 +101,12 @@
|
|||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'textarea'}
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
<Label for={field.key} class="block flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
|
||||
<Textarea
|
||||
|
|
@@ -129,11 +137,15 @@
|
|||
})()}
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<Label for={field.key} class="text-sm font-medium">
|
||||
<Label for={field.key} class="flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
{#if isCustomRealTime}
|
||||
<ParameterSourceIndicator />
|
||||
<ChatSettingsParameterSourceIndicator />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
|
@@ -214,9 +226,13 @@
|
|||
for={field.key}
|
||||
class="cursor-pointer text-sm leading-none font-medium {isDisabled
|
||||
? 'text-muted-foreground'
|
||||
: ''}"
|
||||
: ''} flex items-center gap-1.5"
|
||||
>
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</label>
|
||||
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
<script lang="ts">
	import { Download, Upload } from '@lucide/svelte';
	import { Button } from '$lib/components/ui/button';
	import ConversationSelectionDialog from './ConversationSelectionDialog.svelte';
	import { DialogConversationSelection } from '$lib/components/app';
	import { DatabaseStore } from '$lib/stores/database';
	import type { ExportedConversations } from '$lib/types/database';
	import { createMessageCountMap } from '$lib/utils/conversation-utils';

@@ -236,7 +236,7 @@
	</div>
</div>

<ConversationSelectionDialog
<DialogConversationSelection
	conversations={availableConversations}
	{messageCountMap}
	mode="export"

@@ -245,7 +245,7 @@
	onConfirm={handleExportConfirm}
/>

<ConversationSelectionDialog
<DialogConversationSelection
	conversations={availableConversations}
	{messageCountMap}
	mode="import"

@@ -1,249 +0,0 @@
<script lang="ts">
|
||||
import { Search, X } from '@lucide/svelte';
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { ScrollArea } from '$lib/components/ui/scroll-area';
|
||||
import { SvelteSet } from 'svelte/reactivity';
|
||||
|
||||
interface Props {
|
||||
conversations: DatabaseConversation[];
|
||||
messageCountMap?: Map<string, number>;
|
||||
mode: 'export' | 'import';
|
||||
onCancel: () => void;
|
||||
onConfirm: (selectedConversations: DatabaseConversation[]) => void;
|
||||
open?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
conversations,
|
||||
messageCountMap = new Map(),
|
||||
mode,
|
||||
onCancel,
|
||||
onConfirm,
|
||||
open = $bindable(false)
|
||||
}: Props = $props();
|
||||
|
||||
let searchQuery = $state('');
|
||||
let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
|
||||
let lastClickedId = $state<string | null>(null);
|
||||
|
||||
let filteredConversations = $derived(
|
||||
conversations.filter((conv) => {
|
||||
const name = conv.name || 'Untitled conversation';
|
||||
return name.toLowerCase().includes(searchQuery.toLowerCase());
|
||||
})
|
||||
);
|
||||
|
||||
let allSelected = $derived(
|
||||
filteredConversations.length > 0 &&
|
||||
filteredConversations.every((conv) => selectedIds.has(conv.id))
|
||||
);
|
||||
|
||||
let someSelected = $derived(
|
||||
filteredConversations.some((conv) => selectedIds.has(conv.id)) && !allSelected
|
||||
);
|
||||
|
||||
function toggleConversation(id: string, shiftKey: boolean = false) {
|
||||
const newSet = new SvelteSet(selectedIds);
|
||||
|
||||
if (shiftKey && lastClickedId !== null) {
|
||||
const lastIndex = filteredConversations.findIndex((c) => c.id === lastClickedId);
|
||||
const currentIndex = filteredConversations.findIndex((c) => c.id === id);
|
||||
|
||||
if (lastIndex !== -1 && currentIndex !== -1) {
|
||||
const start = Math.min(lastIndex, currentIndex);
|
||||
const end = Math.max(lastIndex, currentIndex);
|
||||
|
||||
const shouldSelect = !newSet.has(id);
|
||||
|
||||
for (let i = start; i <= end; i++) {
|
||||
if (shouldSelect) {
|
||||
newSet.add(filteredConversations[i].id);
|
||||
} else {
|
||||
newSet.delete(filteredConversations[i].id);
|
||||
}
|
||||
}
|
||||
|
||||
selectedIds = newSet;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (newSet.has(id)) {
|
||||
newSet.delete(id);
|
||||
} else {
|
||||
newSet.add(id);
|
||||
}
|
||||
|
||||
selectedIds = newSet;
|
||||
lastClickedId = id;
|
||||
}
|
||||
|
||||
function toggleAll() {
|
||||
if (allSelected) {
|
||||
const newSet = new SvelteSet(selectedIds);
|
||||
|
||||
filteredConversations.forEach((conv) => newSet.delete(conv.id));
|
||||
selectedIds = newSet;
|
||||
} else {
|
||||
const newSet = new SvelteSet(selectedIds);
|
||||
|
||||
filteredConversations.forEach((conv) => newSet.add(conv.id));
|
||||
selectedIds = newSet;
|
||||
}
|
||||
}
|
||||
|
||||
function handleConfirm() {
|
||||
const selected = conversations.filter((conv) => selectedIds.has(conv.id));
|
||||
onConfirm(selected);
|
||||
}
|
||||
|
||||
function handleCancel() {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
|
||||
onCancel();
|
||||
}
|
||||
|
||||
let previousOpen = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
if (open && !previousOpen) {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
} else if (!open && previousOpen) {
|
||||
onCancel();
|
||||
}
|
||||
|
||||
previousOpen = open;
|
||||
});
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Portal>
|
||||
<Dialog.Overlay class="z-[1000000]" />
|
||||
|
||||
<Dialog.Content class="z-[1000001] max-w-2xl">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>
|
||||
Select Conversations to {mode === 'export' ? 'Export' : 'Import'}
|
||||
</Dialog.Title>
|
||||
|
||||
<Dialog.Description>
|
||||
{#if mode === 'export'}
|
||||
Choose which conversations you want to export. Selected conversations will be downloaded
|
||||
as a JSON file.
|
||||
{:else}
|
||||
Choose which conversations you want to import. Selected conversations will be merged
|
||||
with your existing conversations.
|
||||
{/if}
|
||||
</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="relative">
|
||||
<Search class="absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
|
||||
<Input bind:value={searchQuery} placeholder="Search conversations..." class="pr-9 pl-9" />
|
||||
|
||||
{#if searchQuery}
|
||||
<button
|
||||
class="absolute top-1/2 right-3 -translate-y-1/2 text-muted-foreground hover:text-foreground"
|
||||
onclick={() => (searchQuery = '')}
|
||||
type="button"
|
||||
>
|
||||
<X class="h-4 w-4" />
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between text-sm text-muted-foreground">
|
||||
<span>
|
||||
{selectedIds.size} of {conversations.length} selected
|
||||
{#if searchQuery}
|
||||
({filteredConversations.length} shown)
|
||||
{/if}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="overflow-hidden rounded-md border">
|
||||
<ScrollArea class="h-[400px]">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 z-10 bg-muted">
|
||||
<tr class="border-b">
|
||||
<th class="w-12 p-3 text-left">
|
||||
<Checkbox
|
||||
checked={allSelected}
|
||||
indeterminate={someSelected}
|
||||
onCheckedChange={toggleAll}
|
||||
/>
|
||||
</th>
|
||||
|
||||
<th class="p-3 text-left text-sm font-medium">Conversation Name</th>
|
||||
|
||||
<th class="w-32 p-3 text-left text-sm font-medium">Messages</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#if filteredConversations.length === 0}
|
||||
<tr>
|
||||
<td colspan="3" class="p-8 text-center text-sm text-muted-foreground">
|
||||
{#if searchQuery}
|
||||
No conversations found matching "{searchQuery}"
|
||||
{:else}
|
||||
No conversations available
|
||||
{/if}
|
||||
</td>
|
||||
</tr>
|
||||
{:else}
|
||||
{#each filteredConversations as conv (conv.id)}
|
||||
<tr
|
||||
class="cursor-pointer border-b transition-colors hover:bg-muted/50"
|
||||
onclick={(e) => toggleConversation(conv.id, e.shiftKey)}
|
||||
>
|
||||
<td class="p-3">
|
||||
<Checkbox
|
||||
checked={selectedIds.has(conv.id)}
|
||||
onclick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
toggleConversation(conv.id, e.shiftKey);
|
||||
}}
|
||||
/>
|
||||
</td>
|
||||
|
||||
<td class="p-3 text-sm">
|
||||
<div
|
||||
class="max-w-[17rem] truncate"
|
||||
title={conv.name || 'Untitled conversation'}
|
||||
>
|
||||
{conv.name || 'Untitled conversation'}
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="p-3 text-sm text-muted-foreground">
|
||||
{messageCountMap.get(conv.id) ?? 0}
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
{/if}
|
||||
</tbody>
|
||||
</table>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button variant="outline" onclick={handleCancel}>Cancel</Button>
|
||||
|
||||
<Button onclick={handleConfirm} disabled={selectedIds.size === 0}>
|
||||
{mode === 'export' ? 'Export' : 'Import'} ({selectedIds.size})
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
</Dialog.Root>
|
||||
|
|
@@ -2,7 +2,7 @@
	import { goto } from '$app/navigation';
	import { page } from '$app/state';
	import { Trash2 } from '@lucide/svelte';
	import { ChatSidebarConversationItem, ConfirmationDialog } from '$lib/components/app';
	import { ChatSidebarConversationItem, DialogConfirmation } from '$lib/components/app';
	import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
	import * as Sidebar from '$lib/components/ui/sidebar';
	import * as AlertDialog from '$lib/components/ui/alert-dialog';

@@ -158,7 +158,7 @@
<div class="bottom-0 z-10 bg-sidebar bg-sidebar/50 px-4 py-4 backdrop-blur-lg md:sticky"></div>
</ScrollArea>

<ConfirmationDialog
<DialogConfirmation
	bind:open={showDeleteDialog}
	title="Delete Conversation"
	description={selectedConversation

@@ -0,0 +1,78 @@
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ChatAttachmentPreview } from '$lib/components/app';
|
||||
import { formatFileSize } from '$lib/utils/file-preview';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
// Either an uploaded file or a stored attachment
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
// For uploaded files
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(),
|
||||
uploadedFile,
|
||||
attachment,
|
||||
preview,
|
||||
name,
|
||||
type,
|
||||
size,
|
||||
textContent
|
||||
}: Props = $props();
|
||||
|
||||
let chatAttachmentPreviewRef: ChatAttachmentPreview | undefined = $state();
|
||||
|
||||
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
|
||||
|
||||
let displayType = $derived(
|
||||
uploadedFile?.type ||
|
||||
(attachment?.type === 'imageFile'
|
||||
? 'image'
|
||||
: attachment?.type === 'textFile'
|
||||
? 'text'
|
||||
: attachment?.type === 'audioFile'
|
||||
? attachment.mimeType || 'audio'
|
||||
: attachment?.type === 'pdfFile'
|
||||
? 'application/pdf'
|
||||
: type || 'unknown')
|
||||
);
|
||||
|
||||
let displaySize = $derived(uploadedFile?.size || size);
|
||||
|
||||
$effect(() => {
|
||||
if (open && chatAttachmentPreviewRef) {
|
||||
chatAttachmentPreviewRef.reset();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden sm:w-auto sm:max-w-6xl">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>{displayName}</Dialog.Title>
|
||||
<Dialog.Description>
|
||||
{displayType}
|
||||
{#if displaySize}
|
||||
• {formatFileSize(displaySize)}
|
||||
{/if}
|
||||
</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<ChatAttachmentPreview
|
||||
bind:this={chatAttachmentPreviewRef}
|
||||
{uploadedFile}
|
||||
{attachment}
|
||||
{preview}
|
||||
{name}
|
||||
{type}
|
||||
{textContent}
|
||||
/>
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
|
|
@@ -0,0 +1,51 @@
|
|||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ChatAttachmentsViewAll } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
open?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
imageClass?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(false),
|
||||
uploadedFiles = [],
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto',
|
||||
imageClass = ''
|
||||
}: Props = $props();
|
||||
|
||||
let totalCount = $derived(uploadedFiles.length + attachments.length);
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Portal>
|
||||
<Dialog.Overlay />
|
||||
|
||||
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>All Attachments ({totalCount})</Dialog.Title>
|
||||
<Dialog.Description>View and manage all attached files</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<ChatAttachmentsViewAll
|
||||
{uploadedFiles}
|
||||
{attachments}
|
||||
{readonly}
|
||||
{onFileRemove}
|
||||
{imageHeight}
|
||||
{imageWidth}
|
||||
{imageClass}
|
||||
/>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
</Dialog.Root>
|
||||
|
|
@@ -0,0 +1,37 @@
<script lang="ts">
	import * as Dialog from '$lib/components/ui/dialog';
	import { ChatSettings } from '$lib/components/app';

	interface Props {
		onOpenChange?: (open: boolean) => void;
		open?: boolean;
	}

	let { onOpenChange, open = false }: Props = $props();

	let chatSettingsRef: ChatSettings | undefined = $state();

	function handleClose() {
		onOpenChange?.(false);
	}

	function handleSave() {
		onOpenChange?.(false);
	}

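	// Reset the embedded settings form to the persisted config whenever the dialog is (re)opened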
	$effect(() => {
		if (open && chatSettingsRef) {
			chatSettingsRef.reset();
		}
	});
</script>

<Dialog.Root {open} onOpenChange={handleClose}>
	<Dialog.Content
		class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
		md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
		style="max-width: 48rem;"
	>
		<ChatSettings bind:this={chatSettingsRef} onSave={handleSave} />
	</Dialog.Content>
</Dialog.Root>

@@ -0,0 +1,68 @@
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ConversationSelection } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
conversations: DatabaseConversation[];
|
||||
messageCountMap?: Map<string, number>;
|
||||
mode: 'export' | 'import';
|
||||
onCancel: () => void;
|
||||
onConfirm: (selectedConversations: DatabaseConversation[]) => void;
|
||||
open?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
conversations,
|
||||
messageCountMap = new Map(),
|
||||
mode,
|
||||
onCancel,
|
||||
onConfirm,
|
||||
open = $bindable(false)
|
||||
}: Props = $props();
|
||||
|
||||
let conversationSelectionRef: ConversationSelection | undefined = $state();
|
||||
|
||||
let previousOpen = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
if (open && !previousOpen && conversationSelectionRef) {
|
||||
conversationSelectionRef.reset();
|
||||
} else if (!open && previousOpen) {
|
||||
onCancel();
|
||||
}
|
||||
|
||||
previousOpen = open;
|
||||
});
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Portal>
|
||||
<Dialog.Overlay class="z-[1000000]" />
|
||||
|
||||
<Dialog.Content class="z-[1000001] max-w-2xl">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>
|
||||
Select Conversations to {mode === 'export' ? 'Export' : 'Import'}
|
||||
</Dialog.Title>
|
||||
<Dialog.Description>
|
||||
{#if mode === 'export'}
|
||||
Choose which conversations you want to export. Selected conversations will be downloaded
|
||||
as a JSON file.
|
||||
{:else}
|
||||
Choose which conversations you want to import. Selected conversations will be merged
|
||||
with your existing conversations.
|
||||
{/if}
|
||||
</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<ConversationSelection
|
||||
bind:this={conversationSelectionRef}
|
||||
{conversations}
|
||||
{messageCountMap}
|
||||
{mode}
|
||||
{onCancel}
|
||||
{onConfirm}
|
||||
/>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
</Dialog.Root>
|
||||
|
|
@@ -1,56 +1,63 @@
// Chat

export { default as ChatAttachmentPreview } from './chat/ChatAttachments/ChatAttachmentPreview.svelte';
export { default as ChatAttachmentThumbnailFile } from './chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte';
export { default as ChatAttachmentThumbnailImage } from './chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte';
export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttachmentsList.svelte';
export { default as ChatAttachmentFilePreview } from './chat/ChatAttachments/ChatAttachmentFilePreview.svelte';
export { default as ChatAttachmentImagePreview } from './chat/ChatAttachments/ChatAttachmentImagePreview.svelte';
export { default as ChatAttachmentPreviewDialog } from './chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte';
export { default as ChatAttachmentsViewAllDialog } from './chat/ChatAttachments/ChatAttachmentsViewAllDialog.svelte';
export { default as ChatAttachmentsViewAll } from './chat/ChatAttachments/ChatAttachmentsViewAll.svelte';

export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte';
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';
export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions.svelte';
export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActionFileAttachments.svelte';
export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActionRecord.svelte';
export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte';
export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte';
export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte';
export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte';
export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte';
export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte';
export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte';
export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte';
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';

export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';

export { default as ChatProcessingInfo } from './chat/ChatProcessingInfo.svelte';

export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte';
export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte';
export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte';

export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte';
export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte';
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
export { default as ImportExportTab } from './chat/ChatSettings/ImportExportTab.svelte';
export { default as ConversationSelectionDialog } from './chat/ChatSettings/ConversationSelectionDialog.svelte';
export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte';
export { default as ChatSettingsImportExportTab } from './chat/ChatSettings/ChatSettingsImportExportTab.svelte';
export { default as ChatSettingsParameterSourceIndicator } from './chat/ChatSettings/ChatSettingsParameterSourceIndicator.svelte';

export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';

export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte';
// Dialogs

export { default as DialogChatAttachmentPreview } from './dialogs/DialogChatAttachmentPreview.svelte';
export { default as DialogChatAttachmentsViewAll } from './dialogs/DialogChatAttachmentsViewAll.svelte';
export { default as DialogChatError } from './dialogs/DialogChatError.svelte';
export { default as DialogChatSettings } from './dialogs/DialogChatSettings.svelte';
export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svelte';
export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte';
export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte';
export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.svelte';

// Miscellaneous

export { default as ActionButton } from './misc/ActionButton.svelte';
export { default as ActionDropdown } from './misc/ActionDropdown.svelte';
export { default as ConversationSelection } from './misc/ConversationSelection.svelte';
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';

export { default as MarkdownContent } from './misc/MarkdownContent.svelte';

export { default as RemoveButton } from './misc/RemoveButton.svelte';

// Server

export { default as ServerStatus } from './server/ServerStatus.svelte';
export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte';
export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte';
export { default as ServerInfo } from './server/ServerInfo.svelte';

// Shared components
export { default as ActionButton } from './misc/ActionButton.svelte';
export { default as ActionDropdown } from './misc/ActionDropdown.svelte';
export { default as ConfirmationDialog } from './dialogs/ConfirmationDialog.svelte';

@@ -0,0 +1,205 @@
<script lang="ts">
|
||||
import { Search, X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { ScrollArea } from '$lib/components/ui/scroll-area';
|
||||
import { SvelteSet } from 'svelte/reactivity';
|
||||
|
||||
interface Props {
|
||||
conversations: DatabaseConversation[];
|
||||
messageCountMap?: Map<string, number>;
|
||||
mode: 'export' | 'import';
|
||||
onCancel: () => void;
|
||||
onConfirm: (selectedConversations: DatabaseConversation[]) => void;
|
||||
}
|
||||
|
||||
let { conversations, messageCountMap = new Map(), mode, onCancel, onConfirm }: Props = $props();
|
||||
|
||||
let searchQuery = $state('');
|
||||
let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
|
||||
let lastClickedId = $state<string | null>(null);
|
||||
|
||||
let filteredConversations = $derived(
|
||||
conversations.filter((conv) => {
|
||||
const name = conv.name || 'Untitled conversation';
|
||||
return name.toLowerCase().includes(searchQuery.toLowerCase());
|
||||
})
|
||||
);
|
||||
|
||||
let allSelected = $derived(
|
||||
filteredConversations.length > 0 &&
|
||||
filteredConversations.every((conv) => selectedIds.has(conv.id))
|
||||
);
|
||||
|
||||
let someSelected = $derived(
|
||||
filteredConversations.some((conv) => selectedIds.has(conv.id)) && !allSelected
|
||||
);
|
||||
|
||||
function toggleConversation(id: string, shiftKey: boolean = false) {
|
||||
const newSet = new SvelteSet(selectedIds);
|
||||
|
||||
if (shiftKey && lastClickedId !== null) {
|
||||
const lastIndex = filteredConversations.findIndex((c) => c.id === lastClickedId);
|
||||
const currentIndex = filteredConversations.findIndex((c) => c.id === id);
|
||||
|
||||
if (lastIndex !== -1 && currentIndex !== -1) {
|
||||
const start = Math.min(lastIndex, currentIndex);
|
||||
const end = Math.max(lastIndex, currentIndex);
|
||||
|
||||
const shouldSelect = !newSet.has(id);
|
||||
|
||||
for (let i = start; i <= end; i++) {
|
||||
if (shouldSelect) {
|
||||
newSet.add(filteredConversations[i].id);
|
||||
} else {
|
||||
newSet.delete(filteredConversations[i].id);
|
||||
}
|
||||
}
|
||||
|
||||
selectedIds = newSet;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (newSet.has(id)) {
|
||||
newSet.delete(id);
|
||||
} else {
|
||||
newSet.add(id);
|
||||
}
|
||||
|
||||
selectedIds = newSet;
|
||||
lastClickedId = id;
|
||||
}
|
||||
|
||||
function toggleAll() {
|
||||
if (allSelected) {
|
||||
const newSet = new SvelteSet(selectedIds);
|
||||
|
||||
filteredConversations.forEach((conv) => newSet.delete(conv.id));
|
||||
selectedIds = newSet;
|
||||
} else {
|
||||
const newSet = new SvelteSet(selectedIds);
|
||||
|
||||
filteredConversations.forEach((conv) => newSet.add(conv.id));
|
||||
selectedIds = newSet;
|
||||
}
|
||||
}
|
||||
|
||||
function handleConfirm() {
|
||||
const selected = conversations.filter((conv) => selectedIds.has(conv.id));
|
||||
onConfirm(selected);
|
||||
}
|
||||
|
||||
function handleCancel() {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
|
||||
onCancel();
|
||||
}
|
||||
|
||||
export function reset() {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="relative">
|
||||
<Search class="absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
|
||||
<Input bind:value={searchQuery} placeholder="Search conversations..." class="pr-9 pl-9" />
|
||||
|
||||
{#if searchQuery}
|
||||
<button
|
||||
class="absolute top-1/2 right-3 -translate-y-1/2 text-muted-foreground hover:text-foreground"
|
||||
onclick={() => (searchQuery = '')}
|
||||
type="button"
|
||||
>
|
||||
<X class="h-4 w-4" />
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between text-sm text-muted-foreground">
|
||||
<span>
|
||||
{selectedIds.size} of {conversations.length} selected
|
||||
{#if searchQuery}
|
||||
({filteredConversations.length} shown)
|
||||
{/if}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="overflow-hidden rounded-md border">
|
||||
<ScrollArea class="h-[400px]">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 z-10 bg-muted">
|
||||
<tr class="border-b">
|
||||
<th class="w-12 p-3 text-left">
|
||||
<Checkbox
|
||||
checked={allSelected}
|
||||
indeterminate={someSelected}
|
||||
onCheckedChange={toggleAll}
|
||||
/>
|
||||
</th>
|
||||
|
||||
<th class="p-3 text-left text-sm font-medium">Conversation Name</th>
|
||||
|
||||
<th class="w-32 p-3 text-left text-sm font-medium">Messages</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#if filteredConversations.length === 0}
|
||||
<tr>
|
||||
<td colspan="3" class="p-8 text-center text-sm text-muted-foreground">
|
||||
{#if searchQuery}
|
||||
No conversations found matching "{searchQuery}"
|
||||
{:else}
|
||||
No conversations available
|
||||
{/if}
|
||||
</td>
|
||||
</tr>
|
||||
{:else}
|
||||
{#each filteredConversations as conv (conv.id)}
|
||||
<tr
|
||||
class="cursor-pointer border-b transition-colors hover:bg-muted/50"
|
||||
onclick={(e) => toggleConversation(conv.id, e.shiftKey)}
|
||||
>
|
||||
<td class="p-3">
|
||||
<Checkbox
|
||||
checked={selectedIds.has(conv.id)}
|
||||
onclick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
toggleConversation(conv.id, e.shiftKey);
|
||||
}}
|
||||
/>
|
||||
</td>
|
||||
|
||||
<td class="p-3 text-sm">
|
||||
<div class="max-w-[17rem] truncate" title={conv.name || 'Untitled conversation'}>
|
||||
{conv.name || 'Untitled conversation'}
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="p-3 text-sm text-muted-foreground">
|
||||
{messageCountMap.get(conv.id) ?? 0}
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
{/if}
|
||||
</tbody>
|
||||
</table>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
<div class="flex justify-end gap-2">
|
||||
<Button variant="outline" onclick={handleCancel}>Cancel</Button>
|
||||
|
||||
<Button onclick={handleConfirm} disabled={selectedIds.size === 0}>
|
||||
{mode === 'export' ? 'Export' : 'Import'} ({selectedIds.size})
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@@ -38,7 +38,8 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
	max_tokens: -1,
	custom: '', // custom json-stringified object
	// experimental features
	pyInterpreterEnabled: false
	pyInterpreterEnabled: false,
	enableContinueGeneration: false
};

export const SETTING_CONFIG_INFO: Record<string, string> = {

@@ -96,5 +97,7 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
	modelSelectorEnabled:
		'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.',
	pyInterpreterEnabled:
		'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
		'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
	enableContinueGeneration:
		'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.'
};
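A brief illustrative sketch (editorial addition, not part of this diff) of how a component reads the new experimental flag; it assumes only the `config()` store accessor and the `enableContinueGeneration` key defined above.

const currentConfig = config();
// false by default (see SETTING_CONFIG_DEFAULT); the Continue action is offered only when enabled
const canContinue = Boolean(currentConfig.enableContinueGeneration);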
|
|
|
|||
|
|
@@ -312,7 +312,6 @@ export class ChatService {
|
|||
let aggregatedContent = '';
|
||||
let fullReasoningContent = '';
|
||||
let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
let hasReceivedData = false;
|
||||
let lastTimings: ChatMessageTimings | undefined;
|
||||
let streamFinished = false;
|
||||
let modelEmitted = false;
|
||||
|
|
@@ -352,8 +351,6 @@
|
|||
return;
|
||||
}
|
||||
|
||||
hasReceivedData = true;
|
||||
|
||||
if (!abortSignal?.aborted) {
|
||||
onToolCallChunk?.(serializedToolCalls);
|
||||
}
|
||||
|
|
@@ -415,7 +412,6 @@
|
|||
|
||||
if (content) {
|
||||
finalizeOpenToolCallBatch();
|
||||
hasReceivedData = true;
|
||||
aggregatedContent += content;
|
||||
if (!abortSignal?.aborted) {
|
||||
onChunk?.(content);
|
||||
|
|
@@ -424,7 +420,6 @@
|
|||
|
||||
if (reasoningContent) {
|
||||
finalizeOpenToolCallBatch();
|
||||
hasReceivedData = true;
|
||||
fullReasoningContent += reasoningContent;
|
||||
if (!abortSignal?.aborted) {
|
||||
onReasoningChunk?.(reasoningContent);
|
||||
|
|
@@ -446,15 +441,6 @@
|
|||
if (streamFinished) {
|
||||
finalizeOpenToolCallBatch();
|
||||
|
||||
if (
|
||||
!hasReceivedData &&
|
||||
aggregatedContent.length === 0 &&
|
||||
aggregatedToolCalls.length === 0
|
||||
) {
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
const finalToolCalls =
|
||||
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;
|
||||
|
||||
|
|
|
|||
|
|
@ -1486,6 +1486,10 @@ class ChatStore {
|
|||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
// Ensure currNode points to the edited message to maintain correct path
|
||||
await DatabaseStore.updateCurrentNode(this.activeConversation.id, messageToEdit.id);
|
||||
this.activeConversation.currNode = messageToEdit.id;
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
content: newContent,
|
||||
timestamp: Date.now()
|
||||
|
|
@@ -1499,6 +1503,69 @@
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Edits a user message and preserves all responses below
|
||||
* Updates the message content in-place without deleting or regenerating responses
|
||||
*
|
||||
* **Use Case**: When you want to fix a typo or rephrase a question without losing the assistant's response
|
||||
*
|
||||
* **Important Behavior:**
|
||||
* - Does NOT create a branch (unlike editMessageWithBranching)
|
||||
* - Does NOT regenerate assistant responses
|
||||
* - Only updates the user message content in the database
|
||||
* - Preserves the entire conversation tree below the edited message
|
||||
* - Updates conversation title if this is the first user message
|
||||
*
|
||||
* @param messageId - The ID of the user message to edit
|
||||
* @param newContent - The new content for the message
|
||||
*/
|
||||
async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise<void> {
|
||||
if (!this.activeConversation) return;
|
||||
|
||||
try {
|
||||
const messageIndex = this.findMessageIndex(messageId);
|
||||
if (messageIndex === -1) {
|
||||
console.error('Message not found for editing');
|
||||
return;
|
||||
}
|
||||
|
||||
const messageToEdit = this.activeMessages[messageIndex];
|
||||
if (messageToEdit.role !== 'user') {
|
||||
console.error('Only user messages can be edited with this method');
|
||||
return;
|
||||
}
|
||||
|
||||
// Simply update the message content in-place
|
||||
await DatabaseStore.updateMessage(messageId, {
|
||||
content: newContent,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
content: newContent,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
// Check if first user message for title update
|
||||
const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id);
|
||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||
const isFirstUserMessage =
|
||||
rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user';
|
||||
|
||||
if (isFirstUserMessage && newContent.trim()) {
|
||||
await this.updateConversationTitleWithConfirmation(
|
||||
this.activeConversation.id,
|
||||
newContent.trim(),
|
||||
this.titleUpdateConfirmationCallback
|
||||
);
|
||||
}
|
||||
|
||||
this.updateConversationTimestamp();
|
||||
} catch (error) {
|
||||
console.error('Failed to edit user message:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Edits a message by creating a new branch with the edited content
|
||||
* @param messageId - The ID of the message to edit
|
||||
|
|
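// Usage sketch (not part of this commit): choosing between the in-place edit added above and
// the branching edit that already exists in this store. The handler and the keepResponses
// flag are illustrative, and the (messageId, newContent) signature of editMessageWithBranching
// is assumed to mirror the new method.
async function onSaveUserEdit(messageId: string, newContent: string, keepResponses: boolean) {
  if (keepResponses) {
    // fix a typo without discarding the assistant replies below the message
    await chatStore.editUserMessagePreserveResponses(messageId, newContent);
  } else {
    // branch the conversation and regenerate from the edited message
    await chatStore.editMessageWithBranching(messageId, newContent);
  }
}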
@ -1696,6 +1763,200 @@ class ChatStore {
}
}

/**
* Continues generation for an existing assistant message
* @param messageId - The ID of the assistant message to continue
*/
async continueAssistantMessage(messageId: string): Promise<void> {
if (!this.activeConversation || this.isLoading) return;

try {
const messageIndex = this.findMessageIndex(messageId);
if (messageIndex === -1) {
console.error('Message not found for continuation');
return;
}

const messageToContinue = this.activeMessages[messageIndex];
if (messageToContinue.role !== 'assistant') {
console.error('Only assistant messages can be continued');
return;
}

// Race condition protection: Check if this specific conversation is already loading
// This prevents multiple rapid clicks on "Continue" from creating concurrent operations
if (this.isConversationLoading(this.activeConversation.id)) {
console.warn('Continuation already in progress for this conversation');
return;
}

this.errorDialogState = null;
this.setConversationLoading(this.activeConversation.id, true);
this.clearConversationStreaming(this.activeConversation.id);

// IMPORTANT: Fetch the latest content from the database to ensure we have
// the most up-to-date content, especially after a stopped generation
// This prevents issues where the in-memory state might be stale
const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id);
const dbMessage = allMessages.find((m) => m.id === messageId);

if (!dbMessage) {
console.error('Message not found in database for continuation');
this.setConversationLoading(this.activeConversation.id, false);

return;
}

// Use content from database as the source of truth
const originalContent = dbMessage.content;
const originalThinking = dbMessage.thinking || '';

// Get conversation context up to (but not including) the message to continue
const conversationContext = this.activeMessages.slice(0, messageIndex);

const contextWithContinue = [
...conversationContext.map((msg) => {
if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
return msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] };
}
return msg as ApiChatMessageData;
}),
{
role: 'assistant' as const,
content: originalContent
}
];

let appendedContent = '';
let appendedThinking = '';
let hasReceivedContent = false;

await chatService.sendMessage(
contextWithContinue,
{
...this.getApiOptions(),

onChunk: (chunk: string) => {
hasReceivedContent = true;
appendedContent += chunk;
// Preserve originalContent exactly as-is, including any trailing whitespace
// The concatenation naturally preserves any whitespace at the end of originalContent
const fullContent = originalContent + appendedContent;

this.setConversationStreaming(
messageToContinue.convId,
fullContent,
messageToContinue.id
);

this.updateMessageAtIndex(messageIndex, {
content: fullContent
});
},

onReasoningChunk: (reasoningChunk: string) => {
hasReceivedContent = true;
appendedThinking += reasoningChunk;

const fullThinking = originalThinking + appendedThinking;

this.updateMessageAtIndex(messageIndex, {
thinking: fullThinking
});
},

onComplete: async (
finalContent?: string,
reasoningContent?: string,
timings?: ChatMessageTimings
) => {
const fullContent = originalContent + (finalContent || appendedContent);
const fullThinking = originalThinking + (reasoningContent || appendedThinking);

const updateData: {
content: string;
thinking: string;
timestamp: number;
timings?: ChatMessageTimings;
} = {
content: fullContent,
thinking: fullThinking,
timestamp: Date.now(),
timings: timings
};

await DatabaseStore.updateMessage(messageToContinue.id, updateData);

this.updateMessageAtIndex(messageIndex, updateData);

this.updateConversationTimestamp();

this.setConversationLoading(messageToContinue.convId, false);
this.clearConversationStreaming(messageToContinue.convId);
slotsService.clearConversationState(messageToContinue.convId);
},

onError: async (error: Error) => {
if (this.isAbortError(error)) {
// User cancelled - save partial continuation if any content was received
if (hasReceivedContent && appendedContent) {
const partialContent = originalContent + appendedContent;
const partialThinking = originalThinking + appendedThinking;

await DatabaseStore.updateMessage(messageToContinue.id, {
content: partialContent,
thinking: partialThinking,
timestamp: Date.now()
});

this.updateMessageAtIndex(messageIndex, {
content: partialContent,
thinking: partialThinking,
timestamp: Date.now()
});
}

this.setConversationLoading(messageToContinue.convId, false);
this.clearConversationStreaming(messageToContinue.convId);
slotsService.clearConversationState(messageToContinue.convId);

return;
}

// Non-abort error - rollback to original content
console.error('Continue generation error:', error);

// Rollback: Restore original content in UI
this.updateMessageAtIndex(messageIndex, {
content: originalContent,
thinking: originalThinking
});

// Ensure database has original content (in case of partial writes)
await DatabaseStore.updateMessage(messageToContinue.id, {
content: originalContent,
thinking: originalThinking
});

this.setConversationLoading(messageToContinue.convId, false);
this.clearConversationStreaming(messageToContinue.convId);
slotsService.clearConversationState(messageToContinue.convId);

const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
}
},
messageToContinue.convId
);
} catch (error) {
if (this.isAbortError(error)) return;
console.error('Failed to continue message:', error);
if (this.activeConversation) {
this.setConversationLoading(this.activeConversation.id, false);
}
}
}

/**
* Public methods for accessing per-conversation states
*/
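// Usage sketch (not part of this commit): invoking the continuation from UI code, gated by
// the enableContinueGeneration setting added earlier in this diff. `config` and `message`
// are placeholders for however the caller accesses settings and the rendered message; only
// continueAssistantMessage (bound as a module export in the next hunk) and the setting key
// come from the commit.
async function onContinueClick(message: { id: string; role: string }) {
  if (!config.enableContinueGeneration || message.role !== 'assistant') return;
  await continueAssistantMessage(message.id);
}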
@ -1743,8 +2004,11 @@ export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatSt
export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore);
export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore);
export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore);
export const editUserMessagePreserveResponses =
chatStore.editUserMessagePreserveResponses.bind(chatStore);
export const regenerateMessageWithBranching =
chatStore.regenerateMessageWithBranching.bind(chatStore);
export const continueAssistantMessage = chatStore.continueAssistantMessage.bind(chatStore);
export const deleteMessage = chatStore.deleteMessage.bind(chatStore);
export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore);
export const updateConversationName = chatStore.updateConversationName.bind(chatStore);
@ -7,6 +7,7 @@ export interface SettingsFieldConfig {
key: string;
label: string;
type: 'input' | 'textarea' | 'checkbox' | 'select';
isExperimental?: boolean;
help?: string;
options?: Array<{ value: string; label: string; icon?: typeof import('@lucide/svelte').Icon }>;
}
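// Sketch (not part of this commit): the kind of entry the new isExperimental flag is meant
// for, reusing the help text added to SETTING_CONFIG_INFO earlier in this diff. The actual
// field definition for this setting lives elsewhere in the webui; this literal is illustrative.
const continueGenerationField: SettingsFieldConfig = {
  key: 'enableContinueGeneration',
  label: 'Enable "Continue" button',
  type: 'checkbox',
  isExperimental: true,
  help: 'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.'
};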
@ -1,7 +1,7 @@
<script lang="ts">
import '../app.css';
import { page } from '$app/state';
import { ChatSidebar, ConversationTitleUpdateDialog } from '$lib/components/app';
import { ChatSidebar, DialogConversationTitleUpdate } from '$lib/components/app';
import {
activeMessages,
isLoading,

@ -150,7 +150,7 @@

<Toaster richColors />

<ConversationTitleUpdateDialog
<DialogConversationTitleUpdate
bind:open={titleUpdateDialogOpen}
currentTitle={titleUpdateCurrentTitle}
newTitle={titleUpdateNewTitle}
@ -0,0 +1,19 @@
<script module>
import { defineMeta } from '@storybook/addon-svelte-csf';
import { ChatSettings } from '$lib/components/app';
import { fn } from 'storybook/test';

const { Story } = defineMeta({
title: 'Components/ChatSettings',
component: ChatSettings,
parameters: {
layout: 'fullscreen'
},
args: {
onClose: fn(),
onSave: fn()
}
});
</script>

<Story name="Default" />
@ -1,26 +0,0 @@
<script module>
import { defineMeta } from '@storybook/addon-svelte-csf';
import { ChatSettingsDialog } from '$lib/components/app';
import { fn } from 'storybook/test';

const { Story } = defineMeta({
title: 'Components/ChatSettingsDialog',
component: ChatSettingsDialog,
parameters: {
layout: 'fullscreen'
},
argTypes: {
open: {
control: 'boolean',
description: 'Whether the dialog is open'
}
},
args: {
onOpenChange: fn()
}
});
</script>

<Story name="Open" args={{ open: true }} />

<Story name="Closed" args={{ open: false }} />