model : Apertus model implementation (#15852)
* First attempt * No permute during convert (fixes qk tensors), proper norm application. * RoPE = NeoX * Coherence! * Migrate xielu params from tensors to hyperparameters * Simple CUDA kernel * Revert stupid LLM refactorings * Chat template support * configchecker / flake8 errors * Reorder unary.cu * I do conclude that LLMs are, in fact, stupid. * Fix after merge * Final newline * Make xIELU an UNARY_OP * Final newline * Correctly account for parameter shift * Argh. * Update ggml/src/ggml-cpu/unary-ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Refactor: remove unused methods, inline and factorize softplus, add const modifiers * Revert CUDA changes, implement xIELU as a separate OP * Pesky newline * Add float2half / half2float for F16 inputs/outputs * CUDA variants, attempt 2 * Actually, attempt 3 * Update ggml/src/ggml-cuda/unary.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Missing convert header * Proper formula and reference for xIELU in the comments. * Modify unary-ops.cpp to add the functor-based logic besides the template system to retain optimizations * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Add tensor mappings for Apertus to global list instead * Fix lazy on scalars * Update ggml/src/ggml-cuda/unary.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Add comment about the constraints on positive/negative alpha * Change `softplus` to `ggml_softplus` --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
91a2a56556
commit
34fcc5a4ac
|
|
@ -75,6 +75,35 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) {
|
||||
if (!tool_call.is_object() || tool_call.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the tool name (the single key in the object)
|
||||
auto it = tool_call.begin();
|
||||
std::string name = it.key();
|
||||
|
||||
if (name.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the arguments (the nested object)
|
||||
const json & args_json = it.value();
|
||||
std::string arguments = "";
|
||||
|
||||
if (args_json.is_object()) {
|
||||
arguments = args_json.dump();
|
||||
} else if (args_json.is_string()) {
|
||||
arguments = args_json;
|
||||
} else if (!args_json.is_null()) {
|
||||
// For other types, convert to string representation
|
||||
arguments = args_json.dump();
|
||||
}
|
||||
|
||||
return add_tool_call(name, "", arguments);
|
||||
}
|
||||
void common_chat_msg_parser::finish() {
|
||||
if (!is_partial_ && pos_ != input_.size()) {
|
||||
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ class common_chat_msg_parser {
|
|||
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||
|
||||
// Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
|
||||
bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call);
|
||||
|
||||
void finish();
|
||||
|
||||
bool consume_spaces();
|
||||
|
|
|
|||
110
common/chat.cpp
110
common/chat.cpp
|
|
@ -638,6 +638,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
|||
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
|
||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
|
|
@ -801,6 +802,7 @@ static std::string apply(
|
|||
}
|
||||
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
|
||||
tmpl_inputs.extra_context = inputs.extra_context;
|
||||
tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
|
||||
if (additional_context) {
|
||||
tmpl_inputs.extra_context.merge_patch(*additional_context);
|
||||
}
|
||||
|
|
@ -1264,6 +1266,75 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
|
|||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Generate the prompt using the apply() function with the template
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_APERTUS;
|
||||
|
||||
// Handle thinking tags appropriately based on inputs.enable_thinking
|
||||
if (string_ends_with(data.prompt, "<|inner_prefix|>")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
data.prompt += "<|inner_suffix|>";
|
||||
} else {
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
// When tools are present, build grammar for the <|tools_prefix|> format
|
||||
if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = true;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
schemas.push_back({
|
||||
{ "type", "object" },
|
||||
{ "properties",
|
||||
{
|
||||
{ function.at("name"), function.at("parameters") }
|
||||
} },
|
||||
{ "required", json::array({ function.at("name") }) },
|
||||
});
|
||||
});
|
||||
auto schema = json{
|
||||
{ "type", "array" },
|
||||
{ "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
|
||||
{ "minItems", 1 },
|
||||
};
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root",
|
||||
std::string(data.thinking_forced_open ? "( \"<|inner_suffix|>\" space )? " : "") +
|
||||
"\"<|tools_prefix|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tools_suffix|>\"");
|
||||
});
|
||||
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
||||
// If thinking_forced_open, then we capture the <|inner_suffix|> tag in the grammar,
|
||||
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
|
||||
std::string(data.thinking_forced_open ?
|
||||
"[\\s\\S]*?(<\\|inner_suffix\\|>\\s*)" :
|
||||
"(?:<\\|inner_prefix\\|>[\\s\\S]*?<\\|inner_suffix\\|>\\s*)?") +
|
||||
"(<\\|tools_prefix\\|>)[\\s\\S]*" });
|
||||
data.preserved_tokens = {
|
||||
"<|system_start|>",
|
||||
"<|system_end|>",
|
||||
"<|developer_start|>",
|
||||
"<|developer_end|>",
|
||||
"<|user_start|>",
|
||||
"<|user_end|>",
|
||||
"<|assistant_start|>",
|
||||
"<|assistant_end|>",
|
||||
"<|inner_prefix|>",
|
||||
"<|inner_suffix|>",
|
||||
"<|tools_prefix|>",
|
||||
"<|tools_suffix|>",
|
||||
};
|
||||
}
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
|
|
@ -2323,6 +2394,37 @@ static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
|
|||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_literal("<|tools_suffix|>")) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
for (const auto & value : tool_calls_data.json) {
|
||||
if (value.is_object()) {
|
||||
builder.add_tool_call_short_form(value);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags first - this handles the main reasoning content
|
||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
||||
|
|
@ -2567,6 +2669,11 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
return common_chat_params_init_nemotron_v2(tmpl, params);
|
||||
}
|
||||
|
||||
// Apertus format detection
|
||||
if (src.find("<|system_start|>") != std::string::npos && src.find("<|tools_prefix|>") != std::string::npos) {
|
||||
return common_chat_params_init_apertus(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
|
|
@ -2734,6 +2841,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
|||
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
|
||||
common_chat_parse_nemotron_v2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APERTUS:
|
||||
common_chat_parse_apertus(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ enum common_chat_format {
|
|||
COMMON_CHAT_FORMAT_GPT_OSS,
|
||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
|
|
|||
|
|
@ -8945,6 +8945,43 @@ class SmallThinkerModel(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("ApertusForCausalLM")
|
||||
class ApertusModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.APERTUS
|
||||
undo_permute = False
|
||||
|
||||
_alpha_n = {}
|
||||
_alpha_p = {}
|
||||
_beta = {}
|
||||
_eps = {}
|
||||
|
||||
def modify_tensors(self, data_torch, name, bid):
|
||||
# Handle xIELU activation parameters
|
||||
n_layers = self.hparams["num_hidden_layers"]
|
||||
if name.endswith(".act_fn.alpha_n"):
|
||||
self._alpha_n[bid] = data_torch.to("cpu").float().item()
|
||||
if (len(self._alpha_n) == n_layers):
|
||||
self.gguf_writer.add_xielu_alpha_n([self._alpha_n[k] for k in sorted(self._alpha_n)])
|
||||
return []
|
||||
if name.endswith(".act_fn.alpha_p"):
|
||||
self._alpha_p[bid] = data_torch.to("cpu").float().item()
|
||||
if (len(self._alpha_p) == n_layers):
|
||||
self.gguf_writer.add_xielu_alpha_p([self._alpha_p[k] for k in sorted(self._alpha_p)])
|
||||
return []
|
||||
if name.endswith(".act_fn.beta"):
|
||||
self._beta[bid] = data_torch.to("cpu").float().item()
|
||||
if (len(self._beta) == n_layers):
|
||||
self.gguf_writer.add_xielu_beta([self._beta[k] for k in sorted(self._beta)])
|
||||
return []
|
||||
if name.endswith(".act_fn.eps"):
|
||||
self._eps[bid] = data_torch.to("cpu").float().item()
|
||||
if (len(self._eps) == n_layers):
|
||||
self.gguf_writer.add_xielu_eps([self._eps[k] for k in sorted(self._eps)])
|
||||
return []
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
class MistralModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
model_name = "Mistral"
|
||||
|
|
@ -9112,7 +9149,7 @@ class LazyTorchTensor(gguf.LazyBase):
|
|||
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
|
||||
dtype = cls._dtype_str_map[st_slice.get_dtype()]
|
||||
shape: tuple[int, ...] = tuple(st_slice.get_shape())
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -576,6 +576,7 @@ extern "C" {
|
|||
GGML_UNARY_OP_HARDSIGMOID,
|
||||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_XIELU,
|
||||
|
||||
GGML_UNARY_OP_COUNT,
|
||||
};
|
||||
|
|
@ -1150,6 +1151,18 @@ extern "C" {
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// xIELU activation function
|
||||
// x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
|
||||
// where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
|
||||
// that constrain the positive and negative source alpha values respectively
|
||||
GGML_API struct ggml_tensor * ggml_xielu(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float alpha_n,
|
||||
float alpha_p,
|
||||
float beta,
|
||||
float eps);
|
||||
|
||||
// gated linear unit ops
|
||||
// A: n columns, r rows,
|
||||
// result is n / 2 columns, r rows,
|
||||
|
|
|
|||
|
|
@ -2187,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -8637,7 +8637,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dA = expf(dt_soft_plus * A[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
|
|
@ -8734,7 +8734,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
// dim
|
||||
|
|
@ -8997,6 +8997,10 @@ void ggml_compute_forward_unary(
|
|||
{
|
||||
ggml_compute_forward_exp(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
{
|
||||
ggml_compute_forward_xielu(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
|
|||
|
|
@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
|
|||
return sqrtf(x);
|
||||
}
|
||||
|
||||
static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
|
||||
if (x > 0.0f) {
|
||||
return alpha_p * x * x + beta * x;
|
||||
} else {
|
||||
const float min_x_eps = fminf(x, eps);
|
||||
return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
|
||||
}
|
||||
}
|
||||
|
||||
static inline float op_sin(float x) {
|
||||
return sinf(x);
|
||||
}
|
||||
|
|
@ -121,6 +130,86 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
|||
}
|
||||
}
|
||||
|
||||
template <float (*op)(float, ggml_tensor *)>
|
||||
static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
|
||||
apply_unary_op<op, float, float>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
|
||||
apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
|
||||
apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
|
||||
apply_unary_op<op, ggml_bf16_t, float>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
apply_unary_op<op, ggml_fp16_t, float>(params, dst);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
|
||||
ggml_type_name(dst->type), ggml_type_name(src0->type));
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
// Extend vec_unary_op to support functors
|
||||
template <typename Op, typename src0_t, typename dst_t>
|
||||
static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
|
||||
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
||||
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
y[i] = f32_to_dst(op(src0_to_f32(x[i])));
|
||||
}
|
||||
}
|
||||
|
||||
// Extend apply_unary_op to support functors
|
||||
template <typename Op, typename src0_t, typename dst_t>
|
||||
static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(dst_t));
|
||||
GGML_ASSERT(nb00 == sizeof(src0_t));
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
|
||||
vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
|
||||
}
|
||||
}
|
||||
|
||||
// Generic dispatcher for functors
|
||||
template <typename Op>
|
||||
static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
|
||||
apply_unary_op_functor<Op, float, float>(params, dst, op);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
|
||||
apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
|
||||
apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
|
||||
apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
|
||||
ggml_type_name(dst->type), ggml_type_name(src0->type));
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_abs>(params, dst);
|
||||
}
|
||||
|
|
@ -184,3 +273,17 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor *
|
|||
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_log>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const float alpha_n = ggml_get_op_params_f32(dst, 1);
|
||||
const float alpha_p = ggml_get_op_params_f32(dst, 2);
|
||||
const float beta = ggml_get_op_params_f32(dst, 3);
|
||||
const float eps = ggml_get_op_params_f32(dst, 4);
|
||||
|
||||
const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
|
||||
return op_xielu(f, alpha_n, alpha_p, beta, eps);
|
||||
};
|
||||
|
||||
unary_op_functor(params, dst, xielu_op_params);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
|
|||
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2334,6 +2334,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_UNARY_OP_ELU:
|
||||
ggml_cuda_op_elu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
ggml_cuda_op_xielu(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include "unary.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
static __device__ __forceinline__ float op_abs(float x) {
|
||||
return fabsf(x);
|
||||
|
|
@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
|
||||
}
|
||||
|
||||
/* CUDA kernel + launcher for xIELU */
|
||||
|
||||
template <typename T>
|
||||
static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float xi = ggml_cuda_cast<float>(x[i]);
|
||||
|
||||
const float gate_pos = (xi > 0.0f);
|
||||
const float y_pos = alpha_p * xi * xi + beta * xi;
|
||||
const float min_v_eps = fminf(xi, eps);
|
||||
const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
|
||||
const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
|
||||
|
||||
dst[i] = ggml_cuda_cast<T>(out);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE;
|
||||
xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const void * src0_d = src0->data;
|
||||
void * dst_d = dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
|
||||
const float alpha_n = ggml_get_op_params_f32(dst, 1);
|
||||
const float alpha_p = ggml_get_op_params_f32(dst, 2);
|
||||
const float beta = ggml_get_op_params_f32(dst, 3);
|
||||
const float eps = ggml_get_op_params_f32(dst, 4);
|
||||
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
|
||||
} else {
|
||||
xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* silu_back */
|
||||
|
||||
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#define CUDA_SIN_BLOCK_SIZE 256
|
||||
#define CUDA_COS_BLOCK_SIZE 256
|
||||
#define CUDA_GLU_BLOCK_SIZE 256
|
||||
#define CUDA_XIELU_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
|
|
@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -102,6 +102,9 @@ static bool ggml_op_is_empty(enum ggml_op op) {
|
|||
}
|
||||
}
|
||||
|
||||
static inline float ggml_softplus(float input) {
|
||||
return (input > 20.0f) ? input : logf(1 + expf(input));
|
||||
}
|
||||
//
|
||||
// logging
|
||||
//
|
||||
|
|
|
|||
|
|
@ -1143,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|||
"HARDSIGMOID",
|
||||
"EXP",
|
||||
"GELU_ERF",
|
||||
"XIELU",
|
||||
};
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
|
||||
|
||||
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
|
||||
"REGLU",
|
||||
|
|
@ -2652,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace(
|
|||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
|
||||
}
|
||||
|
||||
// ggml_xielu
|
||||
|
||||
struct ggml_tensor * ggml_xielu(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float alpha_n,
|
||||
float alpha_p,
|
||||
float beta,
|
||||
float eps) {
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
|
||||
ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
|
||||
ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
|
||||
ggml_set_op_params_f32(result, 3, beta);
|
||||
ggml_set_op_params_f32(result, 4, eps);
|
||||
|
||||
result->op = GGML_OP_UNARY;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_silu_back
|
||||
|
||||
struct ggml_tensor * ggml_silu_back(
|
||||
|
|
|
|||
|
|
@ -297,6 +297,13 @@ class Keys:
|
|||
class Diffusion:
|
||||
SHIFT_LOGITS = "diffusion.shift_logits"
|
||||
|
||||
class xIELU:
|
||||
ALPHA_P = "xielu.alpha_p"
|
||||
ALPHA_N = "xielu.alpha_n"
|
||||
BETA = "xielu.beta"
|
||||
EPS = "xielu.eps"
|
||||
|
||||
|
||||
#
|
||||
# recommended mapping of model tensor names for storage in gguf
|
||||
#
|
||||
|
|
@ -405,6 +412,7 @@ class MODEL_ARCH(IntEnum):
|
|||
LLADA_MOE = auto()
|
||||
SEED_OSS = auto()
|
||||
GROVEMOE = auto()
|
||||
APERTUS = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
|
@ -746,6 +754,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.LLADA_MOE: "llada-moe",
|
||||
MODEL_ARCH.SEED_OSS: "seed_oss",
|
||||
MODEL_ARCH.GROVEMOE: "grovemoe",
|
||||
MODEL_ARCH.APERTUS: "apertus",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
|
@ -2706,6 +2715,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.APERTUS: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.LLADA_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
|||
|
|
@ -1084,6 +1084,18 @@ class GGUFWriter:
|
|||
def add_audio_stack_factor(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
|
||||
|
||||
def add_xielu_alpha_p(self, values: Sequence[float]):
|
||||
self.add_array(Keys.xIELU.ALPHA_P, values)
|
||||
|
||||
def add_xielu_alpha_n(self, values: Sequence[float]):
|
||||
self.add_array(Keys.xIELU.ALPHA_N, values)
|
||||
|
||||
def add_xielu_beta(self, values: Sequence[float]):
|
||||
self.add_array(Keys.xIELU.BETA, values)
|
||||
|
||||
def add_xielu_eps(self, values: Sequence[float]):
|
||||
self.add_array(Keys.xIELU.EPS, values)
|
||||
|
||||
# diffusion models
|
||||
|
||||
def add_diffusion_shift_logits(self, value: bool) -> None:
|
||||
|
|
|
|||
|
|
@ -148,6 +148,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.operator_norm", # lfm2
|
||||
"model.transformer.blocks.{bid}.attn_norm", # llada
|
||||
"layers.{bid}.input_layernorm", # qwen3-embedding
|
||||
"model.layers.{bid}.attention_layernorm" # apertus
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
|
|
@ -325,6 +326,7 @@ class TensorNameMap:
|
|||
"model.layers.layers.{bid}.pre_mlp_norm", # plamo2
|
||||
"model.transformer.blocks.{bid}.ff_norm", # llada
|
||||
"layers.{bid}.post_attention_layernorm", # qwen3-embedding
|
||||
"model.layers.{bid}.feedforward_layernorm", # apertus
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
|
|
@ -547,6 +549,7 @@ class TensorNameMap:
|
|||
"transformer.layers.{bid}.attn.q_norm", # openelm
|
||||
"model.layers.layers.{bid}.mixer.q", # plamo2
|
||||
"layers.{bid}.self_attn.q_norm", # qwen3-embedding
|
||||
"model.layers.{bid}.attention.query_layernorm", # apertus
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
|
|
@ -560,6 +563,7 @@ class TensorNameMap:
|
|||
"transformer.layers.{bid}.attn.k_norm", # openelm
|
||||
"model.layers.layers.{bid}.mixer.k", # plamo2
|
||||
"layers.{bid}.self_attn.k_norm", # qwen3-embedding
|
||||
"model.layers.{bid}.attention.key_layernorm", # apertus
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,327 @@
|
|||
{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}
|
||||
{%- if param_spec.type == "array" -%}
|
||||
{%- if param_spec['items'] -%}
|
||||
{%- if param_spec['items']['type'] == "string" -%}
|
||||
{{- "string[]" }}
|
||||
{%- elif param_spec['items']['type'] == "number" -%}
|
||||
{{- "number[]" }}
|
||||
{%- elif param_spec['items']['type'] == "integer" -%}
|
||||
{{- "number[]" }}
|
||||
{%- elif param_spec['items']['type'] == "boolean" -%}
|
||||
{{- "boolean[]" }}
|
||||
{%- else -%}
|
||||
{%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}
|
||||
{%- if inner_type == "object | object" or inner_type|length > 50 -%}
|
||||
{{- "any[]" }}
|
||||
{%- else -%}
|
||||
{{- inner_type + "[]" }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if param_spec.nullable -%}
|
||||
{{- " | null" }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{- "any[]" }}
|
||||
{%- if param_spec.nullable -%}
|
||||
{{- " | null" }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}
|
||||
{#- Handle array of types like ["object", "object"] from Union[dict, list] #}
|
||||
{%- if param_spec.type | length > 1 -%}
|
||||
{{- param_spec.type | join(" | ") }}
|
||||
{%- else -%}
|
||||
{{- param_spec.type[0] }}
|
||||
{%- endif -%}
|
||||
{%- elif param_spec.oneOf -%}
|
||||
{#- Handle oneOf schemas - check for complex unions and fallback to any #}
|
||||
{%- set has_object_variants = false -%}
|
||||
{%- for variant in param_spec.oneOf -%}
|
||||
{%- if variant.type == "object" -%}
|
||||
{%- set has_object_variants = true -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if has_object_variants and param_spec.oneOf|length > 1 -%}
|
||||
{{- "any" }}
|
||||
{%- else -%}
|
||||
{%- for variant in param_spec.oneOf -%}
|
||||
{{- render_typescript_type(variant, required_params) -}}
|
||||
{%- if variant.description %}
|
||||
{{- "// " + variant.description }}
|
||||
{%- endif -%}
|
||||
{%- if variant.default is defined %}
|
||||
{{ "// default: " + variant.default|tojson }}
|
||||
{%- endif -%}
|
||||
{%- if not loop.last %}
|
||||
{{- " | " }}
|
||||
{% endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- elif param_spec.type == "string" -%}
|
||||
{%- if param_spec.enum -%}
|
||||
{{- '"' + param_spec.enum|join('" | "') + '"' -}}
|
||||
{%- else -%}
|
||||
{{- "string" }}
|
||||
{%- if param_spec.nullable %}
|
||||
{{- " | null" }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- elif param_spec.type == "number" -%}
|
||||
{{- "number" }}
|
||||
{%- elif param_spec.type == "integer" -%}
|
||||
{{- "number" }}
|
||||
{%- elif param_spec.type == "boolean" -%}
|
||||
{{- "boolean" }}
|
||||
{%- elif param_spec.type == "object" -%}
|
||||
{%- if param_spec.properties -%}
|
||||
{{- "{\n" }}
|
||||
{%- for prop_name, prop_spec in param_spec.properties.items() -%}
|
||||
{{- prop_name -}}
|
||||
{%- if prop_name not in (param_spec.required or []) -%}
|
||||
{{- "?" }}
|
||||
{%- endif -%}
|
||||
{{- ": " }}
|
||||
{{ render_typescript_type(prop_spec, param_spec.required or []) }}
|
||||
{%- if not loop.last -%}
|
||||
{{-", " }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- "}" }}
|
||||
{%- else -%}
|
||||
{{- "object" }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{- "any" }}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- macro render_tools(tools) -%}
|
||||
{%- for tool in tools %}
|
||||
{{- "// " + tool.description + "\n" }}
|
||||
{{- "type "+ tool.name + " = " }}
|
||||
{%- if tool.parameters and tool.parameters.properties %}
|
||||
{{- "(_: {\n" }}
|
||||
{%- for param_name, param_spec in tool.parameters.properties.items() %}
|
||||
{%- if param_spec.description %}
|
||||
{{- "// " + param_spec.description + "\n" }}
|
||||
{%- endif %}
|
||||
{{- param_name }}
|
||||
{%- if param_name not in (tool.parameters.required or []) -%}
|
||||
{{- "?" }}
|
||||
{%- endif -%}
|
||||
{{- ": " }}
|
||||
{{- render_typescript_type(param_spec, tool.parameters.required or []) }}
|
||||
{%- if param_spec.default is defined -%}
|
||||
{%- if param_spec.enum %}
|
||||
{{- ", // default: " + param_spec.default }}
|
||||
{%- elif param_spec.oneOf %}
|
||||
{{- "// default: " + param_spec.default }}
|
||||
{%- else %}
|
||||
{{- ", // default: " + param_spec.default|tojson }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if not loop.last %}
|
||||
{{- ",\n" }}
|
||||
{%- else %}
|
||||
{{- "\n" }}
|
||||
{%- endif -%}
|
||||
{%- endfor %}
|
||||
{{- "}) => any;" }}
|
||||
{%- else -%}
|
||||
{{- "() => any;" }}
|
||||
{%- endif -%}
|
||||
{%- if not loop.last -%}
|
||||
{{- "\n" }}
|
||||
{%- endif -%}
|
||||
{%- endfor %}
|
||||
{%- endmacro -%}
|
||||
|
||||
{{ bos_token }}
|
||||
|
||||
{%- set system_token = '<|system_start|>' -%}
|
||||
{%- set end_system_token = '<|system_end|>' -%}
|
||||
{%- set developer_token = '<|developer_start|>' -%}
|
||||
{%- set end_developer_token = '<|developer_end|>' -%}
|
||||
{%- set user_token = '<|user_start|>' -%}
|
||||
{%- set end_user_token = '<|user_end|>' -%}
|
||||
{%- set assistant_token = '<|assistant_start|>' -%}
|
||||
{%- set end_assistant_token = '<|assistant_end|>' -%}
|
||||
{%- set inner_token = '<|inner_prefix|>' -%}
|
||||
{%- set outer_token = '<|inner_suffix|>' -%}
|
||||
{%- set tool_calls_token = '<|tools_prefix|>' -%}
|
||||
{%- set end_tool_calls_token = '<|tools_suffix|>' -%}
|
||||
|
||||
{%- set ns = namespace(in_assistant=false, in_tool=false, in_inner=false, assistant_format=none) -%}
|
||||
|
||||
{%- if messages and messages[0].role == 'system' -%}
|
||||
{%- if "content" in messages[0] -%}
|
||||
{%- if messages[0].content is string -%}
|
||||
{{ system_token + messages[0].content + end_system_token }}
|
||||
{%- elif messages[0].content is mapping and "text" in messages[0].content -%}
|
||||
{{ system_token + messages[0].content.text + end_system_token }}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid system message") -}}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid system message") -}}
|
||||
{%- endif -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{{ system_token + 'You are Apertus, a helpful assistant created by the SwissAI initiative.\nKnowledge cutoff: 2024-04\nCurrent date: ' + strftime_now('%Y-%m-%d') + end_system_token }}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{ developer_token + 'Deliberation: ' }}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{ 'enabled\n' }}
|
||||
{%- else -%}
|
||||
{{ 'disabled\n' }}
|
||||
{%- endif -%}
|
||||
{%- if tools is defined and tools -%}
|
||||
{{ 'Tool Capabilities:\n' + render_tools(tools) }}
|
||||
{%- else -%}
|
||||
{{ 'Tool Capabilities: disabled' }}
|
||||
{%- endif -%}
|
||||
{{ end_developer_token }}
|
||||
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if message.role == 'user' -%}
|
||||
{%- set ns.in_inner = false -%}
|
||||
{%- if ns.in_tool -%}
|
||||
{{ ']' }}
|
||||
{%- set ns.in_tool = false -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.in_assistant -%}
|
||||
{{ end_assistant_token }}
|
||||
{%- set ns.in_assistant = false -%}
|
||||
{%- endif -%}
|
||||
{%- if "content" in message -%}
|
||||
{{ user_token }}
|
||||
{%- if message.content is string -%}
|
||||
{{ message.content }}
|
||||
{%- elif message.content is mapping and "parts" in message.content -%}
|
||||
{%- set parts = message.content.parts -%}
|
||||
{%- for part in parts -%}
|
||||
{%- if part.type == "text" -%}
|
||||
{{ part.text }}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid user part: " + part.type) -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid user message: " + message.role) -}}
|
||||
{%- endif -%}
|
||||
{{ end_user_token }}
|
||||
{%- endif -%}
|
||||
{%- elif message.role == 'assistant' -%}
|
||||
{%- if not ns.in_assistant -%}
|
||||
{{ assistant_token }}
|
||||
{%- set ns.in_assistant = true -%}
|
||||
{%- endif -%}
|
||||
{%- if "content" in message and message.content is not none -%}
|
||||
{%- if message.content is string and (ns.assistant_format is none or ns.assistant_format == "string") -%}
|
||||
{%- if ns.in_tool -%}
|
||||
{{ ']' }}
|
||||
{%- set ns.in_tool = false -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.assistant_format = "string" -%}
|
||||
{{ message.content }}
|
||||
{%- elif message.content is mapping and "blocks" in message.content and (ns.assistant_format is none or ns.assistant_format == "mapping") -%}
|
||||
{%- set ns.assistant_format = "mapping" -%}
|
||||
{%- set blocks = message.content.blocks -%}
|
||||
{%- for block in blocks -%}
|
||||
{%- if block.type == 'thoughts' -%}
|
||||
{%- if ns.in_tool -%}
|
||||
{{ ']' }}
|
||||
{%- set ns.in_tool = false -%}
|
||||
{%- endif -%}
|
||||
{%- if not ns.in_inner -%}
|
||||
{%- set ns.in_inner = true -%}
|
||||
{{ inner_token }}
|
||||
{%- endif -%}
|
||||
{{ block.text }}
|
||||
{%- elif block.type == 'tool_calls' -%}
|
||||
{%- if ns.in_tool -%}
|
||||
{{ ']' }}
|
||||
{%- set ns.in_tool = false -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.in_inner and not loop.first and block.calls|length == 1 and block.calls[0].name == 'display_answers' -%}
|
||||
{%- set ns.in_inner = false -%}
|
||||
{{ outer_token }}
|
||||
{%- endif -%}
|
||||
{{ tool_calls_token + '[' }}
|
||||
{%- for tool_call in block.calls -%}
|
||||
{{- '{"' + tool_call.name + '": ' + tool_call.arguments + '}' }}
|
||||
{%- if not loop.last -%}
|
||||
{{- ", " }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{ ']' + end_tool_calls_token }}
|
||||
{%- elif block.type == 'tool_outputs' -%}
|
||||
{%- if ns.in_tool -%}
|
||||
{{- raise_exception("Cannot have both tool outputs as separate messages and tool outputs as blocks") -}}
|
||||
{%- endif -%}
|
||||
{{ '[' }}
|
||||
{%- for tool_output in block.outputs -%}
|
||||
{{- tool_output.output }}
|
||||
{%- if not loop.last -%}
|
||||
{{- ", " }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' }}
|
||||
{%- elif block.type == 'response' -%}
|
||||
{%- if ns.in_tool -%}
|
||||
{{ ']' }}
|
||||
{%- set ns.in_tool = false -%}
|
||||
{%- endif -%}
|
||||
{%- if (not loop.first and ns.in_inner) or (ns.in_assistant and ns.in_inner) -%}
|
||||
{%- set ns.in_inner = false -%}
|
||||
{{ outer_token }}
|
||||
{%- endif -%}
|
||||
{{ block.text }}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid assistant block type: " + block.type) -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid assistant content '" + message.content + "', expected " + ns.assistant_format) -}}
|
||||
{%- endif -%}
|
||||
{%- elif "tool_calls" not in message -%}
|
||||
{{- raise_exception("Invalid assistant message " + message) -}}
|
||||
{%- endif -%}
|
||||
{%- if "tool_calls" in message and message.tool_calls -%}
|
||||
{{ tool_calls_token + '[' }}
|
||||
{%- for tool_call in message.tool_calls -%}
|
||||
{%- if tool_call.type == 'function' -%}
|
||||
{%- set function = tool_call.function -%}
|
||||
{{- '{"' + function.name + '": ' + function.arguments + '}' }}
|
||||
{%- if not loop.last -%}
|
||||
{{- ", " }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid tool call type: " + tool_call.type) -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{ ']' + end_tool_calls_token }}
|
||||
{%- endif -%}
|
||||
{%- elif message.role == 'tool' -%}
|
||||
{%- if not ns.in_assistant -%}
|
||||
{{- raise_exception("Tool message outside of assistant") -}}
|
||||
{%- endif -%}
|
||||
{%- if not ns.in_tool -%}
|
||||
{{ '[' }}
|
||||
{%- set ns.in_tool = true -%}
|
||||
{%- else -%}
|
||||
{{ ", "}}
|
||||
{%- endif -%}
|
||||
{{ message.content }}
|
||||
{%- else -%}
|
||||
{{- raise_exception("Invalid message role") -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if ns.in_tool -%}
|
||||
{{ ']' }}
|
||||
{%- endif -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{ assistant_token }}
|
||||
{%- endif -%}
|
||||
|
|
@ -99,6 +99,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_LLADA_MOE, "llada-moe" },
|
||||
{ LLM_ARCH_SEED_OSS, "seed_oss" },
|
||||
{ LLM_ARCH_GROVEMOE, "grovemoe" },
|
||||
{ LLM_ARCH_APERTUS, "apertus" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
|
@ -256,6 +257,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
|
||||
{ LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },
|
||||
|
||||
{ LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" },
|
||||
{ LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" },
|
||||
{ LLM_KV_XIELU_BETA, "xielu.beta" },
|
||||
{ LLM_KV_XIELU_EPS, "xielu.eps" },
|
||||
|
||||
// deprecated
|
||||
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
|
||||
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
|
||||
|
|
@ -2119,6 +2125,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_APERTUS,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DREAM,
|
||||
{
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ enum llm_arch {
|
|||
LLM_ARCH_LLADA_MOE,
|
||||
LLM_ARCH_SEED_OSS,
|
||||
LLM_ARCH_GROVEMOE,
|
||||
LLM_ARCH_APERTUS,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
@ -260,6 +261,11 @@ enum llm_kv {
|
|||
|
||||
LLM_KV_SHORTCONV_L_CACHE,
|
||||
|
||||
LLM_KV_XIELU_ALPHA_N,
|
||||
LLM_KV_XIELU_ALPHA_P,
|
||||
LLM_KV_XIELU_BETA,
|
||||
LLM_KV_XIELU_EPS,
|
||||
|
||||
// deprecated:
|
||||
LLM_KV_TOKENIZER_PREFIX_ID,
|
||||
LLM_KV_TOKENIZER_SUFFIX_ID,
|
||||
|
|
|
|||
|
|
@ -169,6 +169,12 @@ struct llama_hparams {
|
|||
uint32_t laurel_rank = 64;
|
||||
uint32_t n_embd_altup = 256;
|
||||
|
||||
// xIELU
|
||||
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
|
||||
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
|
||||
std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
|
||||
std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
||||
|
|
|
|||
|
|
@ -465,6 +465,8 @@ namespace GGUFMeta {
|
|||
// TODO: this is not very clever - figure out something better
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required);
|
||||
|
||||
|
||||
llama_model_loader::llama_model_loader(
|
||||
const std::string & fname,
|
||||
|
|
|
|||
|
|
@ -512,9 +512,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
llm_arch_is_recurrent(ml.get_arch()));
|
||||
|
||||
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
|
||||
|
||||
std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
|
||||
|
||||
std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f);
|
||||
std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f);
|
||||
std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f);
|
||||
std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f);
|
||||
|
||||
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
|
||||
|
||||
|
|
@ -2033,6 +2037,19 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_APERTUS:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_8B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -5915,6 +5932,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_APERTUS:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
} else {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
|
||||
// optional bias tensors
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED);
|
||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
|
||||
|
||||
// Q and K layernorms for Apertus
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
|
@ -19099,6 +19158,141 @@ struct llm_build_grovemoe : public llm_graph_context {
|
|||
}
|
||||
};
|
||||
|
||||
struct llm_build_apertus : public llm_graph_context {
|
||||
llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur_pos", il);
|
||||
cb(Kcur, "Kcur_pos", il);
|
||||
cb(Vcur, "Vcur_pos", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network with xIELU activation
|
||||
{
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// Up projection
|
||||
ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur);
|
||||
cb(up, "ffn_up", il);
|
||||
|
||||
float alpha_n_val = hparams.xielu_alpha_n[il];
|
||||
float alpha_p_val = hparams.xielu_alpha_p[il];
|
||||
float beta_val = hparams.xielu_beta[il];
|
||||
float eps_val = hparams.xielu_eps[il];
|
||||
|
||||
// Apply xIELU activation
|
||||
ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val);
|
||||
cb(activated, "ffn_xielu", il);
|
||||
|
||||
// Down projection
|
||||
cur = build_lora_mm(model.layers[il].ffn_down, activated);
|
||||
cb(cur, "ffn_down", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, nullptr,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
||||
llama_memory_i * res;
|
||||
|
||||
|
|
@ -19629,6 +19823,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_grovemoe>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_APERTUS:
|
||||
{
|
||||
llm = std::make_unique<llm_build_apertus>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -19835,6 +20033,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_GLM4_MOE:
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
case LLM_ARCH_GROVEMOE:
|
||||
case LLM_ARCH_APERTUS:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
|
|
|||
|
|
@ -380,6 +380,12 @@ struct llama_layer {
|
|||
// openai-moe
|
||||
struct ggml_tensor * attn_sinks = nullptr;
|
||||
|
||||
// xIELU activation parameters for Apertus
|
||||
struct ggml_tensor * ffn_act_alpha_n = nullptr;
|
||||
struct ggml_tensor * ffn_act_alpha_p = nullptr;
|
||||
struct ggml_tensor * ffn_act_beta = nullptr;
|
||||
struct ggml_tensor * ffn_act_eps = nullptr;
|
||||
|
||||
struct llama_layer_posnet posnet;
|
||||
|
||||
struct llama_layer_convnext convnext;
|
||||
|
|
|
|||
|
|
@ -2054,6 +2054,79 @@ static void test_template_output_parsers() {
|
|||
/* .parse_tool_calls = */ true,
|
||||
}));
|
||||
}
|
||||
{
|
||||
auto tmpls = read_templates("models/templates/Apertus-8B-Instruct.jinja");
|
||||
std::vector<std::string> end_tokens{ "<|assistant_end|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
||||
|
||||
// Test parsing regular content
|
||||
assert_msg_equals(message_assist,
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_APERTUS}));
|
||||
|
||||
// Test parsing content with thinking
|
||||
assert_msg_equals(message_assist_thoughts,
|
||||
common_chat_parse(
|
||||
"<|inner_prefix|>I'm\nthinking<|inner_suffix|>Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ false,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_APERTUS,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
|
||||
// Test parsing tool calls
|
||||
assert_msg_equals(message_assist_call,
|
||||
common_chat_parse(
|
||||
"<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_APERTUS}));
|
||||
|
||||
// Test parsing tool calls with thinking
|
||||
assert_msg_equals(message_assist_call_thoughts,
|
||||
common_chat_parse(
|
||||
"<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>",
|
||||
/* is_partial= */ false,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_APERTUS,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
|
||||
}));
|
||||
|
||||
// Test tool calls with extra content
|
||||
assert_msg_equals(message_assist_call_content,
|
||||
common_chat_parse(
|
||||
"<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_APERTUS}
|
||||
));
|
||||
|
||||
// Test tool calls with extra content AND thinking
|
||||
assert_msg_equals(message_assist_call_thoughts_content,
|
||||
common_chat_parse(
|
||||
"<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ false,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_APERTUS,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
|
||||
}));
|
||||
|
||||
// Test template generation for regular content
|
||||
test_templates(tmpls.get(), end_tokens, message_assist, tools,
|
||||
"Hello, world!\nWhat's up?",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
|
||||
// Test template generation for tool calls
|
||||
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
||||
"<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>",
|
||||
/* expect_grammar_triggered= */ true
|
||||
);
|
||||
|
||||
assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static void test_msg_diffs_compute() {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import re
|
|||
from safetensors.torch import save_file
|
||||
|
||||
# default
|
||||
model_path = './model.pt';
|
||||
model_path = './model.pt'
|
||||
|
||||
# read from CLI
|
||||
if len(sys.argv) > 1:
|
||||
|
|
|
|||
Loading…
Reference in New Issue