Merge remote-tracking branch 'upstream/master' into backend-sampling

This commit is contained in:
Daniel Bevenius 2025-12-19 09:38:01 +01:00
commit bc5195c585
No known key found for this signature in database
46 changed files with 1876 additions and 500 deletions

View File

@ -873,7 +873,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        sampler_type_chars += common_sampler_type_to_chr(sampler);
        sampler_type_names += common_sampler_type_to_str(sampler) + ";";
    }
-   sampler_type_names.pop_back();
    if (!sampler_type_names.empty()) {
        sampler_type_names.pop_back(); // remove last semicolon
    }
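As a note on the guard added above: `std::string::pop_back()` on an empty string is undefined behavior, which is what the empty check protects against when no sampler names were appended. A minimal standalone sketch of the same trim-the-trailing-separator pattern (hypothetical sampler names, not the actual llama.cpp data):

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> names = {"top_k", "top_p", "temperature"}; // hypothetical
    std::string joined;
    for (const auto & n : names) {
        joined += n + ";";
    }
    if (!joined.empty()) {
        joined.pop_back(); // remove last semicolon; pop_back() on an empty string is UB
    }
    std::cout << joined << "\n"; // top_k;top_p;temperature
}
```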
    /**
@ -1194,7 +1196,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params, const std::string & value) {
            params.system_prompt = value;
        }
-   ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION}));
    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_MTMD}));
    add_opt(common_arg(
        {"--perf"},
        {"--no-perf"},

View File

@ -712,6 +712,9 @@ class ModelBase:
if "thinker_config" in config: if "thinker_config" in config:
# rename for Qwen2.5-Omni # rename for Qwen2.5-Omni
config["text_config"] = config["thinker_config"]["text_config"] config["text_config"] = config["thinker_config"]["text_config"]
if "lfm" in config:
# rename for LFM2-Audio
config["text_config"] = config["lfm"]
return config return config
@classmethod @classmethod
@ -9713,12 +9716,12 @@ class LFM2Model(TextModel):
        self._add_feed_forward_length()

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-       is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
-       if is_vision_tensor:
-           # skip vision tensors
        if self._is_vision_tensor(name) or self._is_audio_tensor(name):
            # skip multimodal tensors
            return []
-       name = name.replace("language_model.", "")
        name = name.replace("language_model.", "")  # vision
        name = name.replace("lfm.", "model.")  # audio

        # conv op requires 2d tensor
        if 'conv.conv' in name:
@ -9726,6 +9729,12 @@ class LFM2Model(TextModel):
        return [(self.map_tensor_name(name), data_torch)]

    def _is_vision_tensor(self, name: str) -> bool:
        return "vision_tower" in name or "multi_modal_projector" in name

    def _is_audio_tensor(self, name: str) -> bool:
        return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])


@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
@ -9831,6 +9840,81 @@ class LFM2VLModel(MmprojModel):
            return []  # skip other tensors
@ModelBase.register("Lfm2AudioForConditionalGeneration")
class LFM2AudioModel(MmprojModel):
has_vision_encoder = False
has_audio_encoder = True
model_name = "Lfm2AudioEncoder"
_batch_norm_tensors: list[dict[str, Tensor]] | None = None
def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config.get("encoder")
def set_gguf_parameters(self):
assert self.hparams_audio is not None
self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]
self.hparams_audio["intermediate_size"] = self.hparams_audio["d_model"]
self.hparams_audio["num_attention_heads"] = self.hparams_audio["n_heads"]
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip language model tensors
if name.startswith("lfm."):
return []
# for training only
if any(p in name for p in ["audio_loss_weight"]):
return []
# for audio output
if any(p in name for p in ["codebook_offsets", "depth_embeddings", "depth_linear", "depthformer"]):
return []
# fold running_mean, running_var and eps into weight and bias for batch_norm
if "batch_norm" in name:
if self._batch_norm_tensors is None:
self._batch_norm_tensors = [{} for _ in range(self.block_count)]
assert bid is not None
self._batch_norm_tensors[bid][name] = data_torch
if len(self._batch_norm_tensors[bid]) < 5:
return []
weight = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.weight"]
bias = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.bias"]
running_mean = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_mean"]
running_var = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_var"]
eps = 1e-5 # default value
a = weight / torch.sqrt(running_var + eps)
b = bias - running_mean * a
return [
(self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.weight"), a),
(self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.bias"), b),
]
# reshape conv weights
if name.startswith("conformer.pre_encode.conv.") and name.endswith(".bias"):
data_torch = data_torch[:, None, None]
if "conv.depthwise_conv" in name and name.endswith(".weight"):
assert data_torch.shape[1] == 1
data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[2])
if "conv.pointwise_conv" in name and name.endswith(".weight"):
assert data_torch.shape[2] == 1
data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])
return [(self.map_tensor_name(name), data_torch)]
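The batch-norm folding above relies on a standard identity: at inference time, batch_norm(x) = gamma * (x - mean) / sqrt(var + eps) + beta collapses to the affine map a*x + b with a = gamma / sqrt(var + eps) and b = beta - mean * a, which is exactly the `a` and `b` computed from the five collected tensors. A scalar sanity check of that identity (made-up values, C++ rather than the Python above):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float gamma = 0.9f, beta = 0.1f, mean = 0.5f, var = 2.0f, eps = 1e-5f;
    const float a = gamma / std::sqrt(var + eps); // folded scale
    const float b = beta - mean * a;              // folded bias

    const float x = 1.25f;
    const float bn     = gamma * (x - mean) / std::sqrt(var + eps) + beta;
    const float folded = a * x + b;
    std::printf("bn = %f, folded = %f\n", bn, folded); // equal up to rounding
}
```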
@ModelBase.register("SmallThinkerForCausalLM") @ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel): class SmallThinkerModel(TextModel):
model_arch = gguf.MODEL_ARCH.SMALLTHINKER model_arch = gguf.MODEL_ARCH.SMALLTHINKER

View File

@ -1,27 +1,27 @@
# Android

-## Build with Android Studio
## Build GUI binding using Android Studio

Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project.

-![Project imported into Android Studio](./android/imported-into-android-studio.png)
![Project imported into Android Studio](./android/imported-into-android-studio.jpg)

This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices.
It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration.

A minimal Android app frontend is included to showcase the binding's core functionalities:
-1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` or a local `File`.
-2. **Obtain a `TierDetection` or `InferenceEngine`** instance through the high-level facade APIs.
-3. **Send a raw user prompt** for automatic template formatting, prefill, and decoding. Then collect the generated tokens in a Kotlin `Flow`.
1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver`-provided `Uri` from shared storage, or a local `File` from your app's private storage.
2. **Obtain an `InferenceEngine`** instance through the `AiChat` facade and load your selected model via its app-private file path.
3. **Send a raw user prompt** for automatic template formatting, prefill, and batch decoding. Then collect the generated tokens in a Kotlin `Flow`.

-For a production-ready experience that leverages advanced features such as system prompts and benchmarks, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.
For a production-ready experience that leverages advanced features such as system prompts and benchmarks, plus friendly UI features such as model management and an Arm feature visualizer, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.

This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups:

-| ![Home screen](./android/arm-ai-chat-home-screen.png) | ![System prompt](./android/system-prompt-setup.png) | !["Haiku"](./android/chat-with-system-prompt-haiku.png) |
| ![Home screen](https://naco-siren.github.io/ai-chat/policy/index/1-llm-starter-pack.png) | ![System prompt](https://naco-siren.github.io/ai-chat/policy/index/5-system-prompt.png) | !["Haiku"](https://naco-siren.github.io/ai-chat/policy/index/4-metrics.png) |
|:---:|:---:|:---:|
| Home screen | System prompt | "Haiku" |

-## Build on Android using Termux
## Build CLI on Android using Termux

[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid.

@ -52,7 +52,7 @@ To see what it might look like visually, here's an old demo of an interactive se

https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4

-## Cross-compile using Android NDK
## Cross-compile CLI using Android NDK

It's possible to build `llama.cpp` for Android on your host system via CMake and the Android NDK. If you are interested in this path, ensure you already have an environment prepared to cross-compile programs for Android (i.e., install the Android SDK). Note that, unlike desktop environments, the Android environment ships with a limited set of native libraries, and so only those libraries are available to CMake when building with the Android NDK (see: https://developer.android.com/ndk/guides/stable_apis.)

Once you're ready and have cloned `llama.cpp`, invoke the following in the project directory:

Binary file not shown. (new image, 479 KiB)

View File

@ -1,55 +1,57 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/main"
    android:layout_height="match_parent"
    android:layout_width="match_parent">

    <LinearLayout
        android:fitsSystemWindows="true"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:orientation="vertical"
        android:layout_marginEnd="4dp"
        tools:context=".MainActivity">

-       <FrameLayout
        <ScrollView
            android:layout_width="match_parent"
            android:layout_height="0dp"
-           android:layout_weight="1">
            android:layout_weight="1"
            android:fadeScrollbars="false">

-           <ScrollView
-               android:layout_width="match_parent"
-               android:layout_height="wrap_content"
-               android:fadeScrollbars="false">
-
-               <TextView
-                   android:id="@+id/gguf"
-                   android:layout_width="match_parent"
-                   android:layout_height="wrap_content"
-                   android:layout_margin="16dp"
-                   android:text="Selected GGUF model's metadata will show here."
-                   style="@style/TextAppearance.MaterialComponents.Body2"
-                   android:maxLines="100" />
-           </ScrollView>
-       </FrameLayout>
            <TextView
                android:id="@+id/gguf"
                android:layout_width="match_parent"
                android:layout_height="wrap_content"
                android:layout_margin="16dp"
                android:text="Selected GGUF model's metadata will show here."
                style="@style/TextAppearance.MaterialComponents.Body2" />
        </ScrollView>

        <com.google.android.material.divider.MaterialDivider
            android:layout_width="match_parent"
            android:layout_height="2dp"
            android:layout_marginHorizontal="16dp"
            android:layout_marginVertical="8dp" />

        <androidx.recyclerview.widget.RecyclerView
            android:id="@+id/messages"
            android:layout_width="match_parent"
            android:layout_height="0dp"
            android:layout_weight="4"
            android:padding="16dp"
            android:fadeScrollbars="false"
            android:scrollbars="vertical"
            app:reverseLayout="true"
            tools:listitem="@layout/item_message_assistant"/>

        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
-           android:orientation="horizontal">
            android:orientation="horizontal"
            android:paddingStart="16dp"
            android:paddingEnd="4dp">

            <EditText
                android:id="@+id/user_input"
@ -67,7 +69,7 @@
style="@style/Widget.Material3.FloatingActionButton.Primary" style="@style/Widget.Material3.FloatingActionButton.Primary"
android:layout_width="wrap_content" android:layout_width="wrap_content"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:layout_margin="8dp" android:layout_margin="12dp"
android:src="@drawable/outline_folder_open_24" /> android:src="@drawable/outline_folder_open_24" />
</LinearLayout> </LinearLayout>

View File

@ -2,7 +2,8 @@
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
-   android:padding="8dp"
    android:layout_marginHorizontal="16dp"
    android:layout_marginVertical="8dp"
    android:gravity="start">

    <TextView

View File

@ -2,7 +2,8 @@
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
-   android:padding="8dp"
    android:layout_marginHorizontal="16dp"
    android:layout_marginVertical="8dp"
    android:gravity="end">

    <TextView

View File

@ -2,135 +2,22 @@
import argparse
import os
-import sys
import importlib
from pathlib import Path

-# Add parent directory to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
import torch
import numpy as np
from utils.common import debug_hook
-### === START ROPE DEBUG === ... ### == END ROPE DEBUG ===  (commented-out RoPE monkey-patch example; now setup_rope_debug() in utils/common.py)
-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3): ...  (moved verbatim to utils/common.py)
-def debug_hook(name): ...  (moved verbatim to utils/common.py)
-unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
@ -139,6 +26,12 @@ if model_path is None:
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable" "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
) )
### If you want to dump RoPE activations, uncomment the following lines:
### === START ROPE DEBUG ===
# from utils.common import setup_rope_debug
# setup_rope_debug("transformers.models.apertus.modeling_apertus")
### == END ROPE DEBUG ===
print("Loading model and tokenizer using AutoTokenizer:", model_path) print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@ -156,6 +49,7 @@ print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id) print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id) print("EOS token id: ", config.eos_token_id)
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
@ -184,9 +78,10 @@ else:
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
) )
for name, module in model.named_modules(): if args.verbose:
if len(list(module.children())) == 0: # only leaf modules for name, module in model.named_modules():
module.register_forward_hook(debug_hook(name)) if len(list(module.children())) == 0: # only leaf modules
module.register_forward_hook(debug_hook(name))
model_name = os.path.basename(model_path) model_name = os.path.basename(model_path)
# Printing the Model class to allow for easier debugging. This can be useful # Printing the Model class to allow for easier debugging. This can be useful

View File

@ -2,6 +2,8 @@
import os
import sys

import torch

def get_model_name_from_env_path(env_path_name):
    model_path = os.getenv(env_path_name)
@ -18,3 +20,131 @@ def get_model_name_from_env_path(env_path_name):
        name = name[:-5]
    return name
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
    - 2D tensors (seq, hidden)
    - 3D tensors (batch, seq, hidden)
    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head

    Shows first and last max_vals of each vector per sequence position.
    """
    t = tensor.detach().to(torch.float32).cpu()

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    ten_shape = t.shape

    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for first and last sequences
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(list(set(first_indices + last_indices)))
        separator_index = None
    else:
        # If no overlap, we'll add a separator between first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= max_vals else flat

        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)

        print(f" [{first_str}, ..., {last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")


def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")

    return fn


def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
    """
    Apply monkey patch to dump RoPE activations for debugging.

    Args:
        model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
        function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")

    Example:
        from utils.common import setup_rope_debug
        setup_rope_debug("transformers.models.apertus.modeling_apertus")
    """
    import importlib

    # Import the module and get the original function
    module = importlib.import_module(model_module_path)
    orig_rope = getattr(module, function_name)

    # Set torch print options for better debugging
    torch.set_printoptions(threshold=float('inf'))
    torch.set_printoptions(precision=6, sci_mode=False)

    def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        # log inputs
        summarize(q, "RoPE.q_in")
        summarize(k, "RoPE.k_in")

        # call original
        q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)

        # log outputs
        summarize(q_out, "RoPE.q_out")
        summarize(k_out, "RoPE.k_out")

        return q_out, k_out

    # Patch it
    setattr(module, function_name, debug_rope)
    print(f"RoPE debug patching applied to {model_module_path}.{function_name}")

View File

@ -458,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
            if (GGML_RV_ZFH)
                string(APPEND MARCH_STR "_zfh")
            endif()
            if (GGML_XTHEADVECTOR)
                string(APPEND MARCH_STR "_xtheadvector")
            elseif (GGML_RVV)
@ -465,6 +466,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                if (GGML_RV_ZVFH)
                    string(APPEND MARCH_STR "_zvfh")
                endif()
                if (GGML_RV_ZVFBFWMA)
                    string(APPEND MARCH_STR "_zvfbfwma")
                endif()
            endif()
            if (GGML_RV_ZICBOP)
                string(APPEND MARCH_STR "_zicbop")

View File

@ -3320,13 +3320,33 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
        __m128 y_vec = _mm_cvtph_ps(x_vec);
        _mm_storeu_ps(y + i, y_vec);
    }
-#elif defined(__riscv_zvfh)
-    for (int vl; i < n; i += vl) {
-        vl = __riscv_vsetvl_e16m1(n - i);
-        vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
-        vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
-        __riscv_vse32_v_f32m2(&y[i], vy, vl);
-    }
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfhmin)
    // calculate step size
    const int epr = __riscv_vsetvlmax_e16m2();
    const int step = epr * 2;
    const int np = (n & ~(step - 1));

    // unroll by 2
    for (; i < np; i += step) {
        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, epr);
        vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, epr);
        __riscv_vse32_v_f32m4(y + i, ay0, epr);

        vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16*)x + i + epr, epr);
        vfloat32m4_t ay1 = __riscv_vfwcvt_f_f_v_f32m4(ax1, epr);
        __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
    }

    // leftovers
    int vl;
    for (i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, vl);
        vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, vl);
        __riscv_vse32_v_f32m4(y + i, ay0, vl);
    }
#endif
    for (; i < n; ++i) {
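A quick note on the `np = n & ~(step - 1)` bound used in the unrolled loop above: assuming `step` is a power of two (vector lengths returned by `__riscv_vsetvlmax_*` derive from VLEN, itself a power of two), the mask rounds `n` down to a multiple of `step`, so the main loop covers `np` elements and the strip-mined tail covers the rest. A small check of that arithmetic:

```cpp
#include <cassert>

int main() {
    const int step = 32;                   // stands in for epr * 2; assumed power of two
    for (int n = 0; n < 1000; ++n) {
        const int np = n & ~(step - 1);    // bitwise round-down
        assert(np == (n / step) * step);   // equivalent to integer-division rounding
        assert(np <= n && n - np < step);  // the leftover tail is shorter than one step
    }
}
```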
@ -3371,6 +3391,31 @@ void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
                            (const __m128i *)(x + i))),
                16)));
    }
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfmin)
    // calculate step size
    const int epr = __riscv_vsetvlmax_e16m2();
    const int step = epr * 2;
    const int np = (n & ~(step - 1));

    // unroll by 2
    for (; i < np; i += step) {
        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, epr);
        vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, epr);
        __riscv_vse32_v_f32m4(y + i, ay0, epr);

        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16*)x + i + epr, epr);
        vfloat32m4_t ay1 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax1, epr);
        __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
    }

    // leftovers
    int vl;
    for (i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, vl);
        vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, vl);
        __riscv_vse32_v_f32m4(y + i, ay0, vl);
    }
#endif
    for (; i < n; i++) {
        y[i] = GGML_BF16_TO_FP32(x[i]);
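For reference, the scalar fallback's `GGML_BF16_TO_FP32` exploits the fact that bfloat16 is simply the high 16 bits of an IEEE-754 float, so widening is a 16-bit left shift. A standalone sketch of that conversion (the `bf16` struct here is a stand-in, not ggml's type):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

struct bf16 { uint16_t bits; };                // stand-in for ggml_bf16_t

static float bf16_to_f32(bf16 h) {
    const uint32_t u = (uint32_t)h.bits << 16; // restore the high half, zero the low half
    float f;
    std::memcpy(&f, &u, sizeof(f));            // bit-exact reinterpretation
    return f;
}

int main() {
    const bf16 one{0x3F80};                    // 1.0f truncated to bfloat16
    std::printf("%f\n", bf16_to_f32(one));     // 1.000000
}
```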

View File

@ -195,8 +195,48 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
    sumf += (ggml_float)_mm_cvtss_f32(g);

#undef LOAD
-#endif
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma)
    size_t vl = __riscv_vsetvlmax_e32m4();

    // initialize accumulators to all zeroes
    vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);

    // calculate step size
    const size_t epr = __riscv_vsetvlmax_e16m2();
    const size_t step = epr * 2;
    const int np = (n & ~(step - 1));

    // unroll by 2
    for (; i < np; i += step) {
        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr);
        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr);
        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr);
        __asm__ __volatile__ ("" ::: "memory");

        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr);
        vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr);
        vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr);
        __asm__ __volatile__ ("" ::: "memory");
    }

    // accumulate in 1 register
    vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl);

    // leftovers
    for (i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl);
        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl);
        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl);
    }

    // reduce
    vl = __riscv_vsetvlmax_e32m4();
    vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
    sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
#endif
    for (; i < n; ++i) {
        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
                             GGML_BF16_TO_FP32(y[i]));
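The two accumulators (`vsum0`/`vsum1`) in the unrolled loop above exist to break the loop-carried dependency on a single register, so consecutive widening FMAs can overlap; they are combined once after the loop. A scalar analogue of that structure (made-up data):

```cpp
#include <cstdio>

int main() {
    const float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    const float y[8] = {8, 7, 6, 5, 4, 3, 2, 1};

    float sum0 = 0.0f, sum1 = 0.0f;    // independent partial sums
    for (int i = 0; i < 8; i += 2) {   // unroll by 2
        sum0 += x[i]     * y[i];
        sum1 += x[i + 1] * y[i + 1];
    }
    std::printf("%f\n", sum0 + sum1);  // combined once, after the loop
}
```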

View File

@ -224,13 +224,71 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
    }

    GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
    GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
-#elif defined(__riscv_v_intrinsic)
-    // todo: RVV impl
-    for (int i = 0; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
-        }
-    }
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    size_t vl = __riscv_vsetvlmax_e32m4();

    // initialize accumulators to all zeroes
    vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);

    // calculate step size
    const size_t epr = __riscv_vsetvlmax_e16m2();
    const size_t step = epr * 2;
    const int np = (n & ~(step - 1));

    // unroll by 2 along the row dimension
    for (int i = 0; i < np; i += step) {
        vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
        vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
        vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
        vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
        vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);

        vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
        vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
        vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
        vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
        vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
    }

    vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
    vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);

    // leftovers
    for (int i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
        vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
        vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
        vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
    }

    // reduce
    vl = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
                                                 __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
                                                 __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
    vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
        acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);

    vl = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
                                                 __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
                                                 __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
    vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
        acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);

    sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
    sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
#else
    const int np = (n & ~(GGML_F16_STEP - 1));
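The reduction above narrows the LMUL=4 accumulator in halves (f32m4 -> f32m2 -> f32m1) before the final `vfredusum`. A scalar sketch of the same halving idea, which keeps additions pairwise instead of strictly sequential:

```cpp
#include <cstdio>

int main() {
    float acc[8] = {1, 2, 3, 4, 5, 6, 7, 8}; // stands in for one vector accumulator
    for (int n = 8; n > 1; n /= 2) {
        for (int i = 0; i < n / 2; ++i) {
            acc[i] += acc[i + n / 2];        // fold the upper half onto the lower half
        }
    }
    std::printf("%f\n", acc[0]);             // 36.000000
}
```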
@ -475,15 +533,39 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
    }

    np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
-    const int np = n;
-    _Float16 hv = (_Float16)v;
-    for (int i = 0, avl; i < n; i += avl) {
-        avl = __riscv_vsetvl_e16m8(n - i);
-        vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
-        vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
-        vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
-        __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
-    }
    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
    const _Float16 scale = *(const _Float16*)(&s);

    // calculate step size
    const int epr = __riscv_vsetvlmax_e16m4();
    const int step = epr * 2;
    int np = (n & ~(step - 1));

    // unroll by 2
    for (int i = 0; i < np; i += step) {
        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
        __asm__ __volatile__ ("" ::: "memory");

        vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
        ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
        __asm__ __volatile__ ("" ::: "memory");
    }

    // leftovers
    int vl;
    for (int i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m4(n - i);
        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
    }

    np = n;
#elif defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));
@ -724,13 +806,34 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
        svst1_f16(pg, (__fp16 *)(y + np), out);
    }
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
-    for (int i = 0, vl; i < n; i += vl) {
-        vl = __riscv_vsetvl_e16m2(n - i);
-        vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
-        vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
-        vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
-        vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
-        __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
-    }
    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
    const _Float16 scale = *(const _Float16*)(&s);

    // calculate step size
    const int epr = __riscv_vsetvlmax_e16m4();
    const int step = epr * 2;
    const int np = (n & ~(step - 1));

    // unroll by 2
    for (int i = 0; i < np; i += step) {
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
        __asm__ __volatile__ ("" ::: "memory");

        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
        ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
        __asm__ __volatile__ ("" ::: "memory");
    }

    // leftovers
    int vl;
    for (int i = np; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m4(n - i);
        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
    }
#elif defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));
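Both rewritten kernels above convert the f32 scale to fp16 storage once and then reinterpret those bits as a native `_Float16` for `vfmacc`/`vfmul`. A sketch of that reinterpretation using `memcpy` instead of the pointer cast (same bits, no aliasing concerns; `uint16_t` stands in for `ggml_fp16_t`, and `_Float16` needs compiler/target support):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const uint16_t bits = 0x3C00;          // fp16 bit pattern for 1.0
    _Float16 h;
    std::memcpy(&h, &bits, sizeof(h));     // reinterpret storage bits as a native fp16
    std::printf("%f\n", (double)(float)h); // 1.000000
}
```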

View File

@ -78,27 +78,25 @@ namespace ggml_cuda_mma {
    // MIRRORED == Each data value is held exactly once per thread subgroup.
    DATA_LAYOUT_I_MAJOR = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
    DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
-   DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
    DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
    DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
-   DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
};

// Implemented mma combinations are:
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR

-constexpr bool is_i_major(const data_layout dl) {
static constexpr bool is_i_major(const data_layout dl) {
    return dl == DATA_LAYOUT_I_MAJOR ||
-          dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
-          dl == DATA_LAYOUT_I_MAJOR_DUAL;
           dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
}

-constexpr data_layout get_input_data_layout() {
-#if defined(RDNA3)
-    return DATA_LAYOUT_I_MAJOR_DUAL;
static constexpr __device__ data_layout get_input_data_layout() {
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    return DATA_LAYOUT_I_MAJOR_MIRRORED;
#else
    return DATA_LAYOUT_I_MAJOR;
-#endif // defined(RDNA3)
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
@ -462,11 +460,65 @@ namespace ggml_cuda_mma {
        }
    };
    template <int I_, int J_, typename T>
    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;

        // RDNA3
        static constexpr int ne = I * J / 32 * 2;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 16) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int /*l*/) {
            if constexpr (supported()) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (supported()) {
                return l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
    };
    template <int I_, int J_>
    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
#if defined(RDNA3)
        static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
        }
#else // Volta
        static constexpr int ne = I * J / (WARP_SIZE/4);
        half2 x[ne] = {{0.0f, 0.0f}};
@ -493,6 +545,29 @@ namespace ggml_cuda_mma {
                return -1;
            }
        }
#endif // defined(RDNA3)
    };

    template <int I_, int J_>
    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;

        static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
        }
    };
    template <int I_, int J_>
@ -528,42 +603,6 @@ namespace ggml_cuda_mma {
        }
    };
-   template <int I_, int J_, typename T>
-   struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
-       static constexpr int I = I_;
-       static constexpr int J = J_;
-       static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
-       static constexpr int ne = I * J / 32 * 2;
-       T x[ne] = {0};
-
-       static constexpr __device__ bool supported() {
-           if (I == 16 && J == 16) return true;
-           if (I == 16 && J == 8) return true;
-           if (I == 16 && J == 4) return true;
-           return false;
-       }
-
-       static __device__ __forceinline__ int get_i(const int l) {
-           if constexpr (supported()) {
-               return threadIdx.x % 16;
-           } else {
-               NO_DEVICE_CODE;
-               return -1;
-           }
-       }
-
-       static __device__ __forceinline__ int get_j(const int l) {
-           if constexpr (supported()) {
-               return l;
-           } else {
-               NO_DEVICE_CODE;
-               return -1;
-           }
-       }
-   };
#if defined(TURING_MMA_AVAILABLE)
    template <int I, int J>
    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
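Bookkeeping behind the mirrored tile above, as I read it: a 16x16 tile holds 256 values, and with each value held once per 16-lane subgroup of a 32-lane wave there are 512 slots, i.e. `ne = I*J/32*2 = 16` values per thread, with lane k and lane k+16 mapping to the same row (`threadIdx.x % 16`). A tiny host-side check of that arithmetic (an illustration, not CUDA code):

```cpp
#include <cassert>

int main() {
    const int I = 16, J = 16, lanes = 32, copies = 2;
    const int ne = I * J / lanes * copies;      // values stored per thread
    assert(ne == 16);

    for (int lane = 0; lane < lanes; ++lane) {
        assert(lane % 16 == (lane ^ 16) % 16);  // lane k and its mirror share a row index
    }
}
```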

View File

@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
    const int threads = 128;
    GGML_ASSERT(nr % threads == 0);

-   if (n_t <= 32) {
-       const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-       if (nc == 4) {
-           ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                    dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-       } else if (nc == 3) {
-           ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                    dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-       } else {
-           GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
-       }
-   } else {
-       if (nc == 4) {
-           const int64_t split_n_t = 32;
-           dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-           ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
-               src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-       } else if (nc == 3) {
-           const int64_t split_n_t = 32;
-           dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-           ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
-               src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-       } else {
-           GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
-       }
-   }
    auto launch_kernel = [&](auto NC) {
        constexpr int kNC = decltype(NC)::value;
        if (n_t <= 32) {
            const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
            ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                       dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else {
            const int64_t split_n_t = 32;
            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
            ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        }
    };
    switch (nc) {
        case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
        case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
        case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
        default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
    }
}
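The `launch_kernel` lambda above uses `std::integral_constant` to turn the runtime kernel size `nc` into the compile-time template parameter the kernels require, with one `switch` case per supported size. A minimal standalone sketch of that dispatch pattern (plain C++, with a stub in place of the CUDA kernel):

```cpp
#include <cstdio>
#include <type_traits>

template <int NC>
void kernel() {                                  // stub for the templated kernel
    std::printf("kernel<%d>\n", NC);
}

void dispatch(int nc) {
    auto launch = [&](auto NC) {
        constexpr int kNC = decltype(NC)::value; // runtime choice, recovered as a constant
        kernel<kNC>();
    };
    switch (nc) {
        case 3: launch(std::integral_constant<int, 3>{}); break;
        case 4: launch(std::integral_constant<int, 4>{}); break;
        case 9: launch(std::integral_constant<int, 9>{}); break;
        default: std::printf("unsupported kernel size %d\n", nc);
    }
}

int main() { dispatch(4); } // prints kernel<4>
```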

View File

@ -1527,6 +1527,8 @@ private:
#endif // GGML_VULKAN_MEMORY_DEBUG

static bool vk_perf_logger_enabled = false;
static bool vk_perf_logger_concurrent = false;
static bool vk_enable_sync_logger = false;

// number of calls between perf logger prints
static uint32_t vk_perf_logger_frequency = 1;
@ -1577,14 +1579,14 @@ class vk_perf_logger {
        flops.clear();
    }

-   void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
    std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
        *n_flops = 0;
        std::string fusion_str;
        if (fusion_name) {
            fusion_str = fusion_name + std::string(" ");
        }
        if (node->op == GGML_OP_UNARY) {
-           timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
-           return;
            return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));
        }
        if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
            const uint64_t m = node->ne[0];
@ -1606,9 +1608,8 @@ class vk_perf_logger {
name += " batch=" + std::to_string(batch); name += " batch=" + std::to_string(batch);
} }
name = fusion_str + name; name = fusion_str + name;
timings[name].push_back(time); *n_flops = m * n * (k + (k - 1)) * batch;
flops[name].push_back(m * n * (k + (k - 1)) * batch); return name;
return;
} }
        if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
            std::string name = ggml_op_name(node->op);
@ -1624,20 +1625,17 @@ class vk_perf_logger {
            uint64_t size_M = Cout;
            uint64_t size_K = Cin * KW * KH;
            uint64_t size_N = N * OW * OH;
-           uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
            *n_flops = size_M * size_N * (size_K + (size_K - 1));
            name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
                    ", N=N*OW*OH=" + std::to_string(size_N);
            name = fusion_str + name;
-           flops[name].push_back(n_flops);
-           timings[name].push_back(time);
-           return;
            return name;
        }
        if (node->op == GGML_OP_RMS_NORM) {
            std::string name = ggml_op_name(node->op);
            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
            name = fusion_str + name;
-           timings[name].push_back(time);
-           return;
            return name;
        }
        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
            const ggml_tensor * dst = node;
@ -1653,8 +1651,7 @@ class vk_perf_logger {
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
timings[name.str()].push_back(time); return name.str();
return;
} }
        if (node->op == GGML_OP_TOP_K) {
            std::stringstream name;
@ -1662,11 +1659,38 @@ class vk_perf_logger {
            name << ggml_op_name(node->op) <<
                    " K=" << node->ne[0] <<
                    " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
-           timings[name.str()].push_back(time);
-           return;
            return name.str();
        }
-       timings[fusion_str + ggml_op_name(node->op)].push_back(time);
        return fusion_str + ggml_op_name(node->op);
    }
    void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
        uint64_t n_flops;
        std::string name = get_node_fusion_name(node, fusion_name, &n_flops);
        if (n_flops) {
            flops[name].push_back(n_flops);
        }
        timings[name].push_back(time);
    }

    void log_timing(const std::vector<ggml_tensor *> &nodes, const std::vector<const char *> &names, uint64_t time) {
        uint64_t total_flops = 0;
        std::string name;
        for (size_t n = 0; n < nodes.size(); ++n) {
            uint64_t n_flops = 0;
            name += get_node_fusion_name(nodes[n], names[n], &n_flops);
            total_flops += n_flops;
            if (n != nodes.size() - 1) {
                name += ", ";
            }
        }
        if (total_flops) {
            flops[name].push_back(total_flops);
        }
        timings[name].push_back(time);
    }
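The FLOP counts accumulated above use the usual matmul tally, `m*n*(k + (k - 1))`: each of the `m*n` outputs costs `k` multiplies and `k - 1` additions. A one-liner sanity check of that identity:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const uint64_t m = 4, n = 3, k = 5;
    const uint64_t muls = m * n * k;       // one multiply per product term
    const uint64_t adds = m * n * (k - 1); // k-1 additions per dot product
    assert(muls + adds == m * n * (k + (k - 1)));
}
```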
private:
    std::map<std::string, std::vector<uint64_t>> timings;
    std::map<std::string, std::vector<uint64_t>> flops;
@ -1729,7 +1753,9 @@ struct ggml_backend_vk_context {
    std::unique_ptr<vk_perf_logger> perf_logger;
    vk::QueryPool query_pool;
    std::vector<const char *> query_fusion_names;
    std::vector<int> query_fusion_node_count;
    std::vector<ggml_tensor *> query_nodes;
    std::vector<int> query_node_idx;
    int32_t num_queries {};
    int32_t query_idx {};
}; };
@ -5194,6 +5220,8 @@ static void ggml_vk_instance_init() {
    }

    vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
    vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
    vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;

    const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
    if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
@ -11820,15 +11848,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
        }
    }

-#define ENABLE_SYNC_LOGGING 0

    if (need_sync) {
-#if ENABLE_SYNC_LOGGING
        if (vk_enable_sync_logger) {
            std::cerr << "sync" << std::endl;
-#endif
        }
        ctx->unsynced_nodes_written.clear();
        ctx->unsynced_nodes_read.clear();
        ggml_vk_sync_buffers(ctx, compute_ctx);
        if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
            ctx->query_node_idx[ctx->query_idx] = node_idx;
            compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
        }
    }
    // Add all fused nodes to the unsynchronized lists.
    for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
@ -11845,20 +11876,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
            }
        }
    }

-#if ENABLE_SYNC_LOGGING
    if (vk_enable_sync_logger) {
        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
            auto *n = cgraph->nodes[node_idx + i];
            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
            if (n->op == GGML_OP_GLU) {
                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
            }
            if (n->op == GGML_OP_ROPE) {
                const int mode = ((const int32_t *) n->op_params)[2];
                std::cerr << " rope mode: " << mode;
            }
            std::cerr << std::endl;
        }
    }
-#endif
switch (node->op) { switch (node->op) {
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
@ -13138,12 +13169,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->query_pool = ctx->device->device.createQueryPool(query_create_info); ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
ctx->num_queries = query_create_info.queryCount; ctx->num_queries = query_create_info.queryCount;
ctx->query_fusion_names.resize(ctx->num_queries); ctx->query_fusion_names.resize(ctx->num_queries);
ctx->query_fusion_node_count.resize(ctx->num_queries);
ctx->query_nodes.resize(ctx->num_queries); ctx->query_nodes.resize(ctx->num_queries);
ctx->query_node_idx.resize(ctx->num_queries);
} }
ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1); ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr); std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0);
std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr); std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);
std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
GGML_ASSERT(ctx->compute_ctx.expired()); GGML_ASSERT(ctx->compute_ctx.expired());
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
@ -13272,9 +13307,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
} else { } else {
compute_ctx = ctx->compute_ctx.lock(); compute_ctx = ctx->compute_ctx.lock();
} }
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; if (!vk_perf_logger_concurrent) {
ctx->query_fusion_names[ctx->query_idx] = fusion_string; // track a single node/fusion for the current query
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
} else {
// track a fusion string and number of fused ops for the current node_idx
ctx->query_fusion_names[i] = fusion_string;
ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops;
}
} }
if (enqueued) { if (enqueued) {
@ -13316,12 +13358,32 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Get the results and pass them to the logger // Get the results and pass them to the logger
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1); std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
for (int i = 1; i < ctx->query_idx; i++) { if (!vk_perf_logger_concurrent) {
auto node = ctx->query_nodes[i]; // Log each op separately
auto name = ctx->query_fusion_names[i]; for (int i = 1; i < ctx->query_idx; i++) {
ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod)); auto node = ctx->query_nodes[i];
auto name = ctx->query_fusion_names[i];
ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
}
} else {
// Log each group of nodes
int prev_node_idx = 0;
for (int i = 1; i < ctx->query_idx; i++) {
auto cur_node_idx = ctx->query_node_idx[i];
std::vector<ggml_tensor *> nodes;
std::vector<const char *> names;
for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) {
if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) {
continue;
}
nodes.push_back(cgraph->nodes[node_idx]);
names.push_back(ctx->query_fusion_names[node_idx]);
node_idx += ctx->query_fusion_node_count[node_idx];
}
prev_node_idx = cur_node_idx;
ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
}
} }
ctx->perf_logger->print_timings(); ctx->perf_logger->print_timings();
} }
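With GGML_VK_PERF_LOGGER_CONCURRENT set, the readback above attributes each timestamp delta to the whole group of nodes executed between two sync points: empty ops are skipped and fused ops are collapsed via query_fusion_node_count, so a group is logged under one comma-separated key such as "MUL_MAT, SOFT_MAX" (the exact names come from get_node_fusion_name).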
View File
@ -690,6 +690,8 @@ class MODEL_TENSOR(IntEnum):
V_TOK_EOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm
# audio (mtmd) # audio (mtmd)
A_ENC_EMBD_POS = auto() A_ENC_EMBD_POS = auto()
A_ENC_EMBD_NORM = auto()
A_ENC_EMBD_TO_LOGITS = auto()
A_ENC_CONV1D = auto() A_ENC_CONV1D = auto()
A_PRE_NORM = auto() A_PRE_NORM = auto()
A_POST_NORM = auto() A_POST_NORM = auto()
@ -700,8 +702,13 @@ class MODEL_TENSOR(IntEnum):
A_ENC_OUTPUT = auto() A_ENC_OUTPUT = auto()
A_ENC_OUTPUT_NORM = auto() A_ENC_OUTPUT_NORM = auto()
A_ENC_FFN_UP = auto() A_ENC_FFN_UP = auto()
A_ENC_FFN_NORM = auto()
A_ENC_FFN_GATE = auto() A_ENC_FFN_GATE = auto()
A_ENC_FFN_DOWN = auto() A_ENC_FFN_DOWN = auto()
A_ENC_FFN_UP_1 = auto()
A_ENC_FFN_NORM_1 = auto()
A_ENC_FFN_GATE_1 = auto()
A_ENC_FFN_DOWN_1 = auto()
A_MMPROJ = auto() A_MMPROJ = auto()
A_MMPROJ_FC = auto() A_MMPROJ_FC = auto()
A_MM_NORM_PRE = auto() A_MM_NORM_PRE = auto()
@ -713,6 +720,16 @@ class MODEL_TENSOR(IntEnum):
NEXTN_HNORM = auto() NEXTN_HNORM = auto()
NEXTN_SHARED_HEAD_HEAD = auto() NEXTN_SHARED_HEAD_HEAD = auto()
NEXTN_SHARED_HEAD_NORM = auto() NEXTN_SHARED_HEAD_NORM = auto()
# lfm2 audio
A_ENC_NORM_CONV = auto()
A_ENC_LINEAR_POS = auto()
A_ENC_POS_BIAS_U = auto()
A_ENC_POS_BIAS_V = auto()
A_ENC_OUT = auto()
A_ENC_CONV_DW = auto() # SSM conv
A_ENC_CONV_NORM = auto() # SSM conv
A_ENC_CONV_PW1 = auto()
A_ENC_CONV_PW2 = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -1064,7 +1081,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_BOI: "v.boi",
MODEL_TENSOR.V_TOK_EOI: "v.eoi", MODEL_TENSOR.V_TOK_EOI: "v.eoi",
# audio (mtmd) # audio (mtmd)
# note: all audio tensor names must use prefix "a." or "mm.a."
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln", MODEL_TENSOR.A_POST_NORM: "a.post_ln",
@ -1074,13 +1094,28 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1", MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out", MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2", MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
MODEL_TENSOR.A_ENC_FFN_NORM: "a.blk.{bid}.ffn_norm",
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up", MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate", MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down", MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
MODEL_TENSOR.A_ENC_FFN_NORM_1: "a.blk.{bid}.ffn_norm_1",
MODEL_TENSOR.A_ENC_FFN_UP_1: "a.blk.{bid}.ffn_up_1",
MODEL_TENSOR.A_ENC_FFN_GATE_1: "a.blk.{bid}.ffn_gate_1",
MODEL_TENSOR.A_ENC_FFN_DOWN_1: "a.blk.{bid}.ffn_down_1",
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}", MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
# lfm2 audio
MODEL_TENSOR.A_ENC_NORM_CONV: "a.blk.{bid}.norm_conv",
MODEL_TENSOR.A_ENC_LINEAR_POS: "a.blk.{bid}.linear_pos",
MODEL_TENSOR.A_ENC_POS_BIAS_U: "a.blk.{bid}.pos_bias_u",
MODEL_TENSOR.A_ENC_POS_BIAS_V: "a.blk.{bid}.pos_bias_v",
MODEL_TENSOR.A_ENC_OUT: "a.pre_encode.out",
MODEL_TENSOR.A_ENC_CONV_DW: "a.blk.{bid}.conv_dw",
MODEL_TENSOR.A_ENC_CONV_NORM: "a.blk.{bid}.conv_norm",
MODEL_TENSOR.A_ENC_CONV_PW1: "a.blk.{bid}.conv_pw1",
MODEL_TENSOR.A_ENC_CONV_PW2: "a.blk.{bid}.conv_pw2",
# NextN/MTP # NextN/MTP
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
@ -1145,6 +1180,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_TOK_EOI, MODEL_TENSOR.V_TOK_EOI,
# audio # audio
MODEL_TENSOR.A_ENC_EMBD_POS, MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_EMBD_NORM,
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
MODEL_TENSOR.A_ENC_CONV1D, MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_PRE_NORM, MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM, MODEL_TENSOR.A_POST_NORM,
@ -1154,13 +1191,27 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.A_ENC_INPUT_NORM, MODEL_TENSOR.A_ENC_INPUT_NORM,
MODEL_TENSOR.A_ENC_OUTPUT, MODEL_TENSOR.A_ENC_OUTPUT,
MODEL_TENSOR.A_ENC_OUTPUT_NORM, MODEL_TENSOR.A_ENC_OUTPUT_NORM,
MODEL_TENSOR.A_ENC_FFN_NORM,
MODEL_TENSOR.A_ENC_FFN_UP, MODEL_TENSOR.A_ENC_FFN_UP,
MODEL_TENSOR.A_ENC_FFN_GATE, MODEL_TENSOR.A_ENC_FFN_GATE,
MODEL_TENSOR.A_ENC_FFN_DOWN, MODEL_TENSOR.A_ENC_FFN_DOWN,
MODEL_TENSOR.A_ENC_FFN_NORM_1,
MODEL_TENSOR.A_ENC_FFN_UP_1,
MODEL_TENSOR.A_ENC_FFN_GATE_1,
MODEL_TENSOR.A_ENC_FFN_DOWN_1,
MODEL_TENSOR.A_MMPROJ, MODEL_TENSOR.A_MMPROJ,
MODEL_TENSOR.A_MMPROJ_FC, MODEL_TENSOR.A_MMPROJ_FC,
MODEL_TENSOR.A_MM_NORM_PRE, MODEL_TENSOR.A_MM_NORM_PRE,
MODEL_TENSOR.A_MM_NORM_MID, MODEL_TENSOR.A_MM_NORM_MID,
MODEL_TENSOR.A_ENC_NORM_CONV,
MODEL_TENSOR.A_ENC_LINEAR_POS,
MODEL_TENSOR.A_ENC_POS_BIAS_U,
MODEL_TENSOR.A_ENC_POS_BIAS_V,
MODEL_TENSOR.A_ENC_OUT,
MODEL_TENSOR.A_ENC_CONV_DW,
MODEL_TENSOR.A_ENC_CONV_NORM,
MODEL_TENSOR.A_ENC_CONV_PW1,
MODEL_TENSOR.A_ENC_CONV_PW2,
], ],
MODEL_ARCH.LLAMA: [ MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
@ -3363,6 +3414,7 @@ class VisionProjectorType:
LIGHTONOCR = "lightonocr" LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm" COGVLM = "cogvlm"
JANUS_PRO = "janus_pro" JANUS_PRO = "janus_pro"
LFM2A = "lfm2a" # audio
GLM4V = "glm4v" GLM4V = "glm4v"
View File
@ -1535,10 +1535,20 @@ class TensorNameMap:
MODEL_TENSOR.A_ENC_EMBD_POS: ( MODEL_TENSOR.A_ENC_EMBD_POS: (
"audio_tower.embed_positions", # ultravox "audio_tower.embed_positions", # ultravox
"audio_embedding.embedding", # lfm2
),
MODEL_TENSOR.A_ENC_EMBD_NORM: (
"audio_embedding.embedding_norm", # lfm2
),
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: (
"audio_embedding.to_logits", # lfm2
), ),
MODEL_TENSOR.A_ENC_CONV1D: ( MODEL_TENSOR.A_ENC_CONV1D: (
"audio_tower.conv{bid}", # ultravox "audio_tower.conv{bid}", # ultravox
"conformer.pre_encode.conv.{bid}", # lfm2
), ),
MODEL_TENSOR.A_PRE_NORM: (), MODEL_TENSOR.A_PRE_NORM: (),
@ -1550,36 +1560,76 @@ class TensorNameMap:
MODEL_TENSOR.A_ENC_ATTN_Q: ( MODEL_TENSOR.A_ENC_ATTN_Q: (
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_q", # lfm2
), ),
MODEL_TENSOR.A_ENC_ATTN_K: ( MODEL_TENSOR.A_ENC_ATTN_K: (
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_k", # lfm2
), ),
MODEL_TENSOR.A_ENC_ATTN_V: ( MODEL_TENSOR.A_ENC_ATTN_V: (
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_v", # lfm2
), ),
MODEL_TENSOR.A_ENC_INPUT_NORM: ( MODEL_TENSOR.A_ENC_INPUT_NORM: (
"audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
"conformer.layers.{bid}.norm_self_att", # lfm2
), ),
MODEL_TENSOR.A_ENC_OUTPUT: ( MODEL_TENSOR.A_ENC_OUTPUT: (
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_out", # lfm2
), ),
MODEL_TENSOR.A_ENC_OUTPUT_NORM: ( MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
"audio_tower.layers.{bid}.final_layer_norm", # ultravox "audio_tower.layers.{bid}.final_layer_norm", # ultravox
"conformer.layers.{bid}.norm_out", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_NORM: (
"conformer.layers.{bid}.norm_feed_forward1", # lfm2
), ),
MODEL_TENSOR.A_ENC_FFN_UP: ( MODEL_TENSOR.A_ENC_FFN_UP: (
"audio_tower.layers.{bid}.fc1", # ultravox "audio_tower.layers.{bid}.fc1", # ultravox
"conformer.layers.{bid}.feed_forward1.linear1", # lfm2
), ),
MODEL_TENSOR.A_ENC_FFN_GATE: (), MODEL_TENSOR.A_ENC_FFN_GATE: (),
MODEL_TENSOR.A_ENC_FFN_DOWN: ( MODEL_TENSOR.A_ENC_FFN_DOWN: (
"audio_tower.layers.{bid}.fc2", # ultravox "audio_tower.layers.{bid}.fc2", # ultravox
"conformer.layers.{bid}.feed_forward1.linear2", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_UP_1: (
"conformer.layers.{bid}.feed_forward2.linear1", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_DOWN_1: (
"conformer.layers.{bid}.feed_forward2.linear2", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_NORM_1: (
"conformer.layers.{bid}.norm_feed_forward2", # lfm2
),
MODEL_TENSOR.A_ENC_LINEAR_POS: (
"conformer.layers.{bid}.self_attn.linear_pos", # lfm2
),
MODEL_TENSOR.A_ENC_POS_BIAS_U: (
"conformer.layers.{bid}.self_attn.pos_bias_u", # lfm2
),
MODEL_TENSOR.A_ENC_POS_BIAS_V: (
"conformer.layers.{bid}.self_attn.pos_bias_v", # lfm2
),
MODEL_TENSOR.A_ENC_OUT: (
"conformer.pre_encode.out", # lfm2
), ),
# note: some tensors below have "audio." pseudo-prefix, to prevent conflicts with vision tensors # note: some tensors below have "audio." pseudo-prefix, to prevent conflicts with vision tensors
@ -1587,6 +1637,7 @@ class TensorNameMap:
MODEL_TENSOR.A_MMPROJ: ( MODEL_TENSOR.A_MMPROJ: (
"audio.multi_modal_projector.linear_{bid}", # ultravox "audio.multi_modal_projector.linear_{bid}", # ultravox
"audio_adapter.model.{bid}" # lfm2
), ),
MODEL_TENSOR.A_MMPROJ_FC: ( MODEL_TENSOR.A_MMPROJ_FC: (
@ -1602,6 +1653,26 @@ class TensorNameMap:
"audio.multi_modal_projector.ln_mid", # ultravox "audio.multi_modal_projector.ln_mid", # ultravox
), ),
MODEL_TENSOR.A_ENC_CONV_DW: (
"conformer.layers.{bid}.conv.depthwise_conv", # lfm2
),
MODEL_TENSOR.A_ENC_CONV_NORM: (
"conformer.layers.{bid}.conv.batch_norm", # lfm2
),
MODEL_TENSOR.A_ENC_CONV_PW1: (
"conformer.layers.{bid}.conv.pointwise_conv1", # lfm2
),
MODEL_TENSOR.A_ENC_CONV_PW2: (
"conformer.layers.{bid}.conv.pointwise_conv2", # lfm2
),
MODEL_TENSOR.A_ENC_NORM_CONV: (
"conformer.layers.{bid}.norm_conv", # lfm2
),
# NextN/MTP tensors for GLM4_MOE # NextN/MTP tensors for GLM4_MOE
MODEL_TENSOR.NEXTN_EH_PROJ: ( MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj", "model.layers.{bid}.eh_proj",
View File
@ -288,7 +288,7 @@ class LocalTensor:
data_range: LocalTensorRange data_range: LocalTensorRange
def mmap_bytes(self) -> np.ndarray: def mmap_bytes(self) -> np.ndarray:
return np.memmap(self.data_range.filename, mode='r', offset=self.data_range.offset, shape=self.data_range.size) return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
class SafetensorsLocal: class SafetensorsLocal:
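The np.memmap mode change from 'r' to 'c' maps the tensor data copy-on-write: the returned array is writable in memory while the underlying safetensors file is never modified, presumably so callers can patch tensor bytes in place without taking a full copy.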
View File
@ -7295,11 +7295,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
for (int64_t d_conv : {3, 4}) { for (int64_t d_conv : {3, 4, 9}) {
for (int64_t d_inner: {1024, 1536, 2048}) { for (int64_t d_inner: {1024, 1536, 2048}) {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 4, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
} }
} }
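The new d_conv = 9 case, together with input widths expressed relative to the kernel width, exercises ggml_ssm_conv with the 9-tap depthwise kernel that the LFM2 conformer conv block below appears to rely on (it pads 4 frames on each side).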
View File
@ -15,6 +15,7 @@ add_library(mtmd
clip-graph.h clip-graph.h
models/models.h models/models.h
models/cogvlm.cpp models/cogvlm.cpp
models/conformer.cpp
models/glm4v.cpp models/glm4v.cpp
models/internvl.cpp models/internvl.cpp
models/kimivl.cpp models/kimivl.cpp
View File
@ -138,6 +138,21 @@
#define TN_TOK_BOI "v.boi" #define TN_TOK_BOI "v.boi"
#define TN_TOK_EOI "v.eoi" #define TN_TOK_EOI "v.eoi"
// (conformer) lfm2
#define TN_PRE_ENCODE_OUT "a.pre_encode.out.%s"
#define TN_FFN_NORM "%s.blk.%d.ffn_norm.%s"
#define TN_FFN_NORM_1 "%s.blk.%d.ffn_norm_1.%s"
#define TN_FFN_UP_1 "%s.blk.%d.ffn_up_1.%s"
#define TN_FFN_DOWN_1 "%s.blk.%d.ffn_down_1.%s"
#define TN_POS_BIAS_U "%s.blk.%d.pos_bias_u"
#define TN_POS_BIAS_V "%s.blk.%d.pos_bias_v"
#define TN_NORM_CONV "%s.blk.%d.norm_conv.%s"
#define TN_LINEAR_POS "%s.blk.%d.linear_pos.%s"
#define TN_CONV_DW "%s.blk.%d.conv_dw.%s"
#define TN_CONV_NORM "%s.blk.%d.conv_norm.%s"
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"
// align x to upper multiple of n // align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@ -170,6 +185,7 @@ enum projector_type {
PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_UNKNOWN, PROJECTOR_TYPE_UNKNOWN,
}; };
@ -198,6 +214,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_GLM4V, "glm4v"},
}; };
View File
@ -4,6 +4,7 @@
#include "clip.h" #include "clip.h"
#include "clip-impl.h" #include "clip-impl.h"
#include <array>
#include <vector> #include <vector>
#include <unordered_set> #include <unordered_set>
#include <cstdint> #include <cstdint>
@ -142,6 +143,30 @@ struct clip_layer {
ggml_tensor * deepstack_fc2_w = nullptr; ggml_tensor * deepstack_fc2_w = nullptr;
ggml_tensor * deepstack_fc2_b = nullptr; ggml_tensor * deepstack_fc2_b = nullptr;
// lfm2
ggml_tensor * ff_norm_w = nullptr;
ggml_tensor * ff_norm_b = nullptr;
ggml_tensor * ff_norm_1_w = nullptr;
ggml_tensor * ff_norm_1_b = nullptr;
ggml_tensor * ff_up_1_w = nullptr;
ggml_tensor * ff_up_1_b = nullptr;
ggml_tensor * ff_down_1_w = nullptr;
ggml_tensor * ff_down_1_b = nullptr;
ggml_tensor * pos_bias_u = nullptr;
ggml_tensor * pos_bias_v = nullptr;
ggml_tensor * norm_conv_w = nullptr;
ggml_tensor * norm_conv_b = nullptr;
ggml_tensor * linear_pos_w = nullptr;
ggml_tensor * conv_norm_w = nullptr;
ggml_tensor * conv_norm_b = nullptr;
ggml_tensor * conv_dw_w = nullptr;
ggml_tensor * conv_dw_b = nullptr;
ggml_tensor * conv_pw1_w = nullptr;
ggml_tensor * conv_pw1_b = nullptr;
ggml_tensor * conv_pw2_w = nullptr;
ggml_tensor * conv_pw2_b = nullptr;
bool has_deepstack() const { bool has_deepstack() const {
return deepstack_fc1_w != nullptr; return deepstack_fc1_w != nullptr;
} }
@ -286,6 +311,12 @@ struct clip_model {
ggml_tensor * mm_boi = nullptr; ggml_tensor * mm_boi = nullptr;
ggml_tensor * mm_eoi = nullptr; ggml_tensor * mm_eoi = nullptr;
// lfm2 audio
std::array<ggml_tensor *, 7> pre_encode_conv_X_w = {nullptr};
std::array<ggml_tensor *, 7> pre_encode_conv_X_b = {nullptr};
ggml_tensor * pre_encode_out_w = nullptr;
ggml_tensor * pre_encode_out_b = nullptr;
bool audio_has_avgpool() const { bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL; || proj_type == PROJECTOR_TYPE_VOXTRAL;
View File
@ -837,6 +837,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{ {
builder = std::make_unique<clip_graph_llava>(ctx, img); builder = std::make_unique<clip_graph_llava>(ctx, img);
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
builder = std::make_unique<clip_graph_conformer>(ctx, img);
} break;
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
{ {
builder = std::make_unique<clip_graph_glm4v>(ctx, img); builder = std::make_unique<clip_graph_glm4v>(ctx, img);
@ -1187,6 +1191,15 @@ struct clip_model_loader {
hparams.audio_window_len = 400; hparams.audio_window_len = 400;
hparams.audio_hop_len = 160; hparams.audio_hop_len = 160;
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
// audio preprocessing params
hparams.audio_chunk_len = 1; // in seconds
hparams.audio_sample_rate = 16000;
hparams.audio_n_fft = 512;
hparams.audio_window_len = 400;
hparams.audio_hop_len = 160;
} break;
default: default:
break; break;
} }
@ -1611,6 +1624,52 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
for (int i : {0, 2, 3, 5, 6}) {
model.pre_encode_conv_X_w[i] = get_tensor(string_format(TN_CONV1D, i, "weight"));
model.pre_encode_conv_X_b[i] = get_tensor(string_format(TN_CONV1D, i, "bias"));
}
model.pre_encode_out_w = get_tensor(string_format(TN_PRE_ENCODE_OUT, "weight"));
model.pre_encode_out_b = get_tensor(string_format(TN_PRE_ENCODE_OUT, "bias"));
model.mm_0_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "weight"));
model.mm_0_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "bias"));
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
model.mm_3_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "weight"));
model.mm_3_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "bias"));
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = model.layers[il];
layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
layer.ff_norm_b = get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias"));
layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
layer.ff_norm_1_b = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "bias"));
layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"));
layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"));
layer.pos_bias_u = get_tensor(string_format(TN_POS_BIAS_U, prefix, il));
layer.pos_bias_v = get_tensor(string_format(TN_POS_BIAS_V, prefix, il));
layer.norm_conv_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"));
layer.norm_conv_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"));
layer.linear_pos_w = get_tensor(string_format(TN_LINEAR_POS, prefix, il, "weight"));
layer.conv_norm_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"));
layer.conv_norm_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"));
layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"));
layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"));
layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"));
}
} break;
default: default:
GGML_ASSERT(false && "unknown projector type"); GGML_ASSERT(false && "unknown projector type");
} }
@ -3004,6 +3063,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{ {
n_patches += 2; // for BOI and EOI token embeddings n_patches += 2; // for BOI and EOI token embeddings
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
} break;
default: default:
GGML_ABORT("unsupported projector type"); GGML_ABORT("unsupported projector type");
} }
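As a rough sanity check of this formula, here is a hypothetical standalone sketch (lfm2a_n_tokens is an illustrative name, not part of the patch) of the three ceil-by-2 subsampling stages:

#include <cstdio>

// ceil-division by 2 applied three times, mirroring clip_n_output_tokens for PROJECTOR_TYPE_LFM2A
static int lfm2a_n_tokens(int n_mel_frames) {
    int t = (n_mel_frames + 1) / 2; // stride-2 conv, pre_encode layer 0
    t = (t + 1) / 2;                // stride-2 depthwise conv, layer 2
    t = (t + 1) / 2;                // stride-2 depthwise conv, layer 5
    return t;
}

int main() {
    // ~10 s of 16 kHz audio with a 160-sample hop gives about 1000 mel frames
    std::printf("%d\n", lfm2a_n_tokens(1000)); // prints 125
    return 0;
}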
@ -3362,6 +3425,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} }
set_input_i32("pos_w", pos_data); set_input_i32("pos_w", pos_data);
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
GGML_ASSERT(imgs.entries.size() == 1);
const auto n_frames = clip_n_output_tokens(ctx, imgs.entries.front().get());
auto d_model = 512;
auto seq_len = n_frames * 2 - 1;
std::vector<float> pos_emb(d_model*seq_len);
std::vector<double> inv_freq(d_model / 2);
for (size_t i = 0; i < inv_freq.size(); ++i) {
inv_freq[i] = std::exp(-(std::log(10000.0) / (float)d_model) * (2.0f * (float)(i)));
}
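// relative positions run from (n_frames - 1) down to -(n_frames - 1); each dimension
// pair stores sin (even index) and cos (odd index) of pos * inv_freq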
for (int64_t pos = 0; pos < seq_len; ++pos) {
for (size_t i = 0; i < inv_freq.size(); ++i) {
const float ang = (n_frames - pos - 1) * inv_freq[i];
pos_emb[pos*d_model + 2*i + 0] = sinf(ang); // even
pos_emb[pos*d_model + 2*i + 1] = cosf(ang); // odd
}
}
set_input_f32("pos_emb", pos_emb);
} break;
default: default:
GGML_ABORT("Unknown projector type"); GGML_ABORT("Unknown projector type");
} }
@ -3456,6 +3540,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1]; return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1]; return ctx->model.mm_4h_to_h_w->ne[1];
case PROJECTOR_TYPE_LFM2A:
return ctx->model.position_embeddings->ne[0];
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
return ctx->model.mm_ffn_down_w->ne[1]; return ctx->model.mm_ffn_down_w->ne[1];
default: default:
View File
@ -0,0 +1,217 @@
#include "models.h"
ggml_cgraph * clip_graph_conformer::build() {
const int n_frames = img.nx;
const int n_pos = n_frames / 2;
const int n_pos_embd = (((((n_frames + 1) / 2) + 1) / 2 + 1) / 2) * 2 - 1;
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
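// n_pos_embd is the number of relative positions, 2*n_tokens - 1 for the token count after 8x subsampling;
// the hardcoded d_model of 512 must match the host-side "pos_emb" fill in clip_image_batch_encode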
ggml_tensor * pos_emb = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 512, n_pos_embd);
ggml_set_name(pos_emb, "pos_emb");
ggml_set_input(pos_emb);
ggml_build_forward_expand(gf, pos_emb);
ggml_tensor * inp = build_inp_raw(1);
cb(inp, "input", -1);
auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
// pre encode, conv subsampling
{
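// three stride-2 stages (layers 0, 2 and 5) downsample the time and mel axes by 8x overall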
// layer.0 - conv2d
cur = ggml_conv_2d(ctx0, model.pre_encode_conv_X_w[0], cur, 2, 2, 1, 1, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[0]);
cb(cur, "conformer.pre_encode.conv.{}", 0);
// layer.1 - relu
cur = ggml_relu_inplace(ctx0, cur);
// layer.2 conv2d dw
cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[2], cur, 2, 2, 1, 1, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[2]);
cb(cur, "conformer.pre_encode.conv.{}", 2);
// layer.3 conv2d
cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[3], cur, 1, 1, 0, 0, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[3]);
cb(cur, "conformer.pre_encode.conv.{}", 3);
// layer.4 - relu
cur = ggml_relu_inplace(ctx0, cur);
// layer.5 conv2d dw
cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[5], cur, 2, 2, 1, 1, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[5]);
cb(cur, "conformer.pre_encode.conv.{}", 5);
// layer.6 conv2d
cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[6], cur, 1, 1, 0, 0, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[6]);
cb(cur, "conformer.pre_encode.conv.{}", 6);
// layer.7 - relu
cur = ggml_relu_inplace(ctx0, cur);
// flatten channel and frequency axes
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
// output linear projection (pre_encode.out)
cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur);
cur = ggml_add(ctx0, cur, model.pre_encode_out_b);
cb(cur, "conformer.pre_encode.out", -1);
}
// pos_emb
cb(pos_emb, "pos_emb", -1);
for (int il = 0; il < hparams.n_layer; il++) {
const auto & layer = model.layers[il];
auto * residual = cur;
cb(cur, "layer.in", il);
// feed_forward1
cur = build_norm(cur, layer.ff_norm_w, layer.ff_norm_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_feed_forward1", il);
cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, layer.ff_down_b, FFN_SILU,
il);
cb(cur, "conformer.layers.{}.feed_forward1.linear2", il);
const auto fc_factor = 0.5f;
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
// self-attention
{
cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_self_att", il);
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]);
ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u);
Q_bias_u = ggml_permute(ctx0, Q_bias_u, 0, 2, 1, 3);
ggml_tensor * Q_bias_v = ggml_add(ctx0, Qcur, layer.pos_bias_v);
Q_bias_v = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3);
// TODO @ngxson : some ggml_cont calls can/should be removed once ggml_mul_mat supports these cases
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]);
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]);
Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3));
// build_attn doesn't fit here because the scores are split into separate matrix_ac and matrix_bd terms
ggml_tensor * matrix_ac = ggml_mul_mat(ctx0, Q_bias_u, Kcur);
matrix_ac = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3));
cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il);
auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb);
cb(p, "conformer.layers.{}.self_attn.linear_pos", il);
p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]);
p = ggml_permute(ctx0, p, 0, 2, 1, 3);
auto * matrix_bd = ggml_mul_mat(ctx0, Q_bias_v, p);
matrix_bd = ggml_cont(ctx0, ggml_permute(ctx0, matrix_bd, 1, 0, 2, 3));
// rel shift
{
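// Transformer-XL style relative shift: pad a zero column, roll it to the front, then
// re-view the (q_len, pos_len + 1) buffer with a one-row offset so that row q of the
// scores is shifted by q, aligning each query with its relative positions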
const auto pos_len = matrix_bd->ne[0];
const auto q_len = matrix_bd->ne[1];
const auto h = matrix_bd->ne[2];
matrix_bd = ggml_pad(ctx0, matrix_bd, 1, 0, 0, 0);
matrix_bd = ggml_roll(ctx0, matrix_bd, 1, 0, 0, 0);
matrix_bd = ggml_reshape_3d(ctx0, matrix_bd, q_len, pos_len + 1, h);
matrix_bd = ggml_view_3d(ctx0, matrix_bd, q_len, pos_len, h, matrix_bd->nb[1],
matrix_bd->nb[2], matrix_bd->nb[0] * q_len);
matrix_bd = ggml_cont_3d(ctx0, matrix_bd, pos_len, q_len, h);
}
matrix_bd = ggml_view_3d(ctx0, matrix_bd, matrix_ac->ne[0], matrix_bd->ne[1],
matrix_bd->ne[2], matrix_bd->nb[1], matrix_bd->nb[2], 0);
auto * scores = ggml_add(ctx0, matrix_ac, matrix_bd);
scores = ggml_scale(ctx0, scores, 1.0f / std::sqrt(d_head));
cb(scores, "conformer.layers.{}.self_attn.id0", il);
ggml_tensor * attn = ggml_soft_max(ctx0, scores);
ggml_tensor * x = ggml_mul_mat(ctx0, attn, Vcur);
x = ggml_permute(ctx0, x, 2, 0, 1, 3);
x = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]);
ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x);
out = ggml_add(ctx0, out, layer.o_b);
cb(out, "conformer.layers.{}.self_attn.linear_out", il);
cur = out;
}
residual = ggml_add(ctx0, residual, cur);
cur = build_norm(residual, layer.norm_conv_w, layer.norm_conv_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_conv", il);
// conv
{
auto * x = cur;
x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x);
x = ggml_add(ctx0, x, layer.conv_pw1_b);
cb(x, "conformer.layers.{}.conv.pointwise_conv1", il);
// ggml_glu doesn't support sigmoid
// TODO @ngxson : support this op in ggml
{
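// first half of the channels is the value, second half (offset d) is the gate,
// matching GLU with a sigmoid activation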
int64_t d = x->ne[0] / 2;
ggml_tensor * gate = ggml_sigmoid(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], d * x->nb[0]));
x = ggml_mul(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], 0), gate);
x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
}
// use ggml_ssm_conv for f32 precision
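// ggml_pad appends zeros at the end of dim 0, so pad(4) + roll(4) places 4 zeros at
// the front and the second pad(4) adds 4 at the back: a 4+4 halo for the 9-tap
// depthwise kernel, preserving sequence length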
x = ggml_pad(ctx0, x, 4, 0, 0, 0);
x = ggml_roll(ctx0, x, 4, 0, 0, 0);
x = ggml_pad(ctx0, x, 4, 0, 0, 0);
x = ggml_ssm_conv(ctx0, x, layer.conv_dw_w);
x = ggml_add(ctx0, x, layer.conv_dw_b);
x = ggml_add(ctx0, ggml_mul(ctx0, x, layer.conv_norm_w), layer.conv_norm_b);
x = ggml_silu(ctx0, x);
// pointwise_conv2
x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x);
x = ggml_add(ctx0, x, layer.conv_pw2_b);
cur = x;
}
residual = ggml_add(ctx0, residual, cur);
cur = build_norm(residual, layer.ff_norm_1_w, layer.ff_norm_1_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_feed_forward2", il);
cur = build_ffn(cur, layer.ff_up_1_w, layer.ff_up_1_b, nullptr, nullptr, layer.ff_down_1_w, layer.ff_down_1_b,
FFN_SILU, il); // TODO(tarek): read activation for ffn from hparams
cb(cur, "conformer.layers.{}.feed_forward2.linear2", il);
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
cb(residual, "conformer.layers.{}.conv.id", il);
cur = build_norm(residual, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_out", il);
}
// audio adapter
cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
cb(cur, "audio_adapter.model.{}", 0);
cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_3_w, model.mm_3_b, FFN_GELU_ERF, -1);
cb(cur, "projected", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
View File
@ -57,6 +57,11 @@ struct clip_graph_whisper_enc : clip_graph {
ggml_cgraph * build() override; ggml_cgraph * build() override;
}; };
struct clip_graph_conformer : clip_graph {
clip_graph_conformer(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_glm4v : clip_graph { struct clip_graph_glm4v : clip_graph {
clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override; ggml_cgraph * build() override;
View File
@ -535,3 +535,56 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
return true; return true;
} }
//
// mtmd_audio_preprocessor_conformer
//
void mtmd_audio_preprocessor_conformer::initialize() {
g_cache.fill_sin_cos_table(hparams.audio_n_fft);
g_cache.fill_hann_window(hparams.audio_window_len, true);
g_cache.fill_mel_filterbank_matrix(
hparams.n_mel_bins,
hparams.audio_n_fft,
hparams.audio_sample_rate);
}
bool mtmd_audio_preprocessor_conformer::preprocess(
const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
// empty audio
if (n_samples == 0) {
return false;
}
filter_params params;
params.n_mel = hparams.n_mel_bins;
params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
params.hann_window_size = hparams.audio_window_len;
params.hop_length = hparams.audio_hop_len;
params.sample_rate = hparams.audio_sample_rate;
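// a 400-sample window and 160-sample hop at 16 kHz give 25 ms frames every 10 ms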
params.center_padding = true;
params.preemph = 0.97f;
params.use_natural_log = true;
params.norm_per_feature = true;
// make sure the global cache is initialized
GGML_ASSERT(!g_cache.sin_vals.empty());
GGML_ASSERT(!g_cache.cos_vals.empty());
GGML_ASSERT(!g_cache.filters.data.empty());
mtmd_audio_mel out_full;
bool ok = log_mel_spectrogram(
samples,
n_samples,
4, // n_threads
params,
out_full);
if (!ok) {
return false;
}
output.push_back(std::move(out_full));
return true;
}
View File
@ -32,3 +32,9 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor {
void initialize() override; void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override; bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
}; };
struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
};
View File
@ -309,9 +309,24 @@ int main(int argc, char ** argv) {
if (g_is_interrupted) return 130; if (g_is_interrupted) return 130;
auto eval_system_prompt_if_present = [&] {
if (params.system_prompt.empty()) {
return 0;
}
common_chat_msg msg;
msg.role = "system";
msg.content = params.system_prompt;
return eval_message(ctx, msg);
};
LOG_WRN("WARN: This is an experimental CLI for testing multimodal capability.\n"); LOG_WRN("WARN: This is an experimental CLI for testing multimodal capability.\n");
LOG_WRN(" For normal use cases, please use the standard llama-cli\n"); LOG_WRN(" For normal use cases, please use the standard llama-cli\n");
if (eval_system_prompt_if_present()) {
return 1;
}
if (is_single_turn) { if (is_single_turn) {
g_is_generating = true; g_is_generating = true;
if (params.prompt.find(mtmd_default_marker()) == std::string::npos) { if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
@ -321,6 +336,7 @@ int main(int argc, char ** argv) {
params.prompt = mtmd_default_marker() + params.prompt; params.prompt = mtmd_default_marker() + params.prompt;
} }
} }
common_chat_msg msg; common_chat_msg msg;
msg.role = "user"; msg.role = "user";
msg.content = params.prompt; msg.content = params.prompt;
@ -369,6 +385,9 @@ int main(int argc, char ** argv) {
ctx.n_past = 0; ctx.n_past = 0;
ctx.chat_history.clear(); ctx.chat_history.clear();
llama_memory_clear(llama_get_memory(ctx.lctx), true); llama_memory_clear(llama_get_memory(ctx.lctx), true);
if (eval_system_prompt_if_present()) {
return 1;
}
LOG("Chat history cleared\n\n"); LOG("Chat history cleared\n\n");
continue; continue;
} }
View File
@ -332,6 +332,9 @@ struct mtmd_context {
case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_GLMA:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a); audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
break; break;
case PROJECTOR_TYPE_LFM2A:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
break;
default: default:
GGML_ABORT("unsupported audio projector type"); GGML_ABORT("unsupported audio projector type");
} }
View File
@ -84,6 +84,7 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M" add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0"
# to test the big models, run: ./tests.sh big # to test the big models, run: ./tests.sh big
if [ "$RUN_BIG_TESTS" = true ]; then if [ "$RUN_BIG_TESTS" = true ]; then
Binary file not shown.

View File
@ -2109,9 +2109,9 @@
} }
}, },
"node_modules/@sveltejs/kit": { "node_modules/@sveltejs/kit": {
"version": "2.48.5", "version": "2.49.2",
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz", "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.49.2.tgz",
"integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==", "integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
@ -5797,9 +5797,9 @@
} }
}, },
"node_modules/mdast-util-to-hast": { "node_modules/mdast-util-to-hast": {
"version": "13.2.0", "version": "13.2.1",
"resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.0.tgz", "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.1.tgz",
"integrity": "sha512-QGYKEuUsYT9ykKBCMOEDLsU5JRObWQusAolFMeko/tYPufNkRffBAQjIE+99jbA87xv6FgmjLtwjh9wBWajwAA==", "integrity": "sha512-cctsq2wp5vTsLIcaymblUriiTcZd0CwWtCbLvrOzYCDZoWyMNV8sZ7krj09FSnsiJi3WVsHLM4k6Dq/yaPyCXA==",
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"@types/hast": "^3.0.0", "@types/hast": "^3.0.0",
View File
@ -124,3 +124,10 @@ declare global {
SettingsConfigType SettingsConfigType
}; };
} }
declare global {
interface Window {
idxThemeStyle?: number;
idxCodeBlock?: number;
}
}
View File
@ -244,7 +244,7 @@
<div class="info my-6 grid gap-4"> <div class="info my-6 grid gap-4">
{#if displayedModel()} {#if displayedModel()}
<span class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground"> <div class="inline-flex flex-wrap items-start gap-2 text-xs text-muted-foreground">
{#if isRouter} {#if isRouter}
<ModelsSelector <ModelsSelector
currentModel={displayedModel()} currentModel={displayedModel()}
@ -258,11 +258,13 @@
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms} {#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
<ChatMessageStatistics <ChatMessageStatistics
promptTokens={message.timings.prompt_n}
promptMs={message.timings.prompt_ms}
predictedTokens={message.timings.predicted_n} predictedTokens={message.timings.predicted_n}
predictedMs={message.timings.predicted_ms} predictedMs={message.timings.predicted_ms}
/> />
{/if} {/if}
</span> </div>
{/if} {/if}
{#if config().showToolCalls} {#if config().showToolCalls}
View File
@ -1,20 +1,122 @@
<script lang="ts"> <script lang="ts">
import { Clock, Gauge, WholeWord } from '@lucide/svelte'; import { Clock, Gauge, WholeWord, BookOpenText, Sparkles } from '@lucide/svelte';
import { BadgeChatStatistic } from '$lib/components/app'; import { BadgeChatStatistic } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { ChatMessageStatsView } from '$lib/enums';
interface Props { interface Props {
predictedTokens: number; predictedTokens: number;
predictedMs: number; predictedMs: number;
promptTokens?: number;
promptMs?: number;
} }
let { predictedTokens, predictedMs }: Props = $props(); let { predictedTokens, predictedMs, promptTokens, promptMs }: Props = $props();
let activeView: ChatMessageStatsView = $state(ChatMessageStatsView.GENERATION);
let tokensPerSecond = $derived((predictedTokens / predictedMs) * 1000); let tokensPerSecond = $derived((predictedTokens / predictedMs) * 1000);
let timeInSeconds = $derived((predictedMs / 1000).toFixed(2)); let timeInSeconds = $derived((predictedMs / 1000).toFixed(2));
let promptTokensPerSecond = $derived(
promptTokens !== undefined && promptMs !== undefined
? (promptTokens / promptMs) * 1000
: undefined
);
let promptTimeInSeconds = $derived(
promptMs !== undefined ? (promptMs / 1000).toFixed(2) : undefined
);
let hasPromptStats = $derived(
promptTokens !== undefined &&
promptMs !== undefined &&
promptTokensPerSecond !== undefined &&
promptTimeInSeconds !== undefined
);
</script> </script>
<BadgeChatStatistic icon={WholeWord} value="{predictedTokens} tokens" /> <div class="inline-flex items-center text-xs text-muted-foreground">
<div class="inline-flex items-center rounded-sm bg-muted-foreground/15 p-0.5">
{#if hasPromptStats}
<Tooltip.Root>
<Tooltip.Trigger>
<button
type="button"
class="inline-flex h-5 w-5 items-center justify-center rounded-sm transition-colors {activeView ===
ChatMessageStatsView.READING
? 'bg-background text-foreground shadow-sm'
: 'hover:text-foreground'}"
onclick={() => (activeView = ChatMessageStatsView.READING)}
>
<BookOpenText class="h-3 w-3" />
<span class="sr-only">Reading</span>
</button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>Reading (prompt processing)</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
<Tooltip.Root>
<Tooltip.Trigger>
<button
type="button"
class="inline-flex h-5 w-5 items-center justify-center rounded-sm transition-colors {activeView ===
ChatMessageStatsView.GENERATION
? 'bg-background text-foreground shadow-sm'
: 'hover:text-foreground'}"
onclick={() => (activeView = ChatMessageStatsView.GENERATION)}
>
<Sparkles class="h-3 w-3" />
<span class="sr-only">Generation</span>
</button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>Generation (token output)</p>
</Tooltip.Content>
</Tooltip.Root>
</div>
<BadgeChatStatistic icon={Clock} value="{timeInSeconds}s" /> <div class="flex items-center gap-1 px-2">
{#if activeView === ChatMessageStatsView.GENERATION}
<BadgeChatStatistic icon={Gauge} value="{tokensPerSecond.toFixed(2)} tokens/s" /> <BadgeChatStatistic
class="bg-transparent"
icon={WholeWord}
value="{predictedTokens} tokens"
tooltipLabel="Generated tokens"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Clock}
value="{timeInSeconds}s"
tooltipLabel="Generation time"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Gauge}
value="{tokensPerSecond.toFixed(2)} tokens/s"
tooltipLabel="Generation speed"
/>
{:else if hasPromptStats}
<BadgeChatStatistic
class="bg-transparent"
icon={WholeWord}
value="{promptTokens} tokens"
tooltipLabel="Prompt tokens"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Clock}
value="{promptTimeInSeconds}s"
tooltipLabel="Prompt processing time"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Gauge}
value="{promptTokensPerSecond!.toFixed(2)} tokens/s"
tooltipLabel="Prompt processing speed"
/>
{/if}
</div>
</div>
View File
@ -587,7 +587,7 @@
&::after { &::after {
content: ''; content: '';
position: fixed; position: absolute;
bottom: 0; bottom: 0;
z-index: -1; z-index: -1;
left: 0; left: 0;
View File
@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import { BadgeInfo } from '$lib/components/app'; import { BadgeInfo } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { copyToClipboard } from '$lib/utils'; import { copyToClipboard } from '$lib/utils';
import type { Component } from 'svelte'; import type { Component } from 'svelte';
@ -7,19 +8,37 @@
class?: string; class?: string;
icon: Component; icon: Component;
value: string | number; value: string | number;
tooltipLabel?: string;
} }
let { class: className = '', icon: Icon, value }: Props = $props(); let { class: className = '', icon: Icon, value, tooltipLabel }: Props = $props();
function handleClick() { function handleClick() {
void copyToClipboard(String(value)); void copyToClipboard(String(value));
} }
</script> </script>
<BadgeInfo class={className} onclick={handleClick}> {#if tooltipLabel}
{#snippet icon()} <Tooltip.Root>
<Icon class="h-3 w-3" /> <Tooltip.Trigger>
{/snippet} <BadgeInfo class={className} onclick={handleClick}>
{#snippet icon()}
<Icon class="h-3 w-3" />
{/snippet}
{value} {value}
</BadgeInfo> </BadgeInfo>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{tooltipLabel}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
<BadgeInfo class={className} onclick={handleClick}>
{#snippet icon()}
<Icon class="h-3 w-3" />
{/snippet}
{value}
</BadgeInfo>
{/if}
View File
@ -7,15 +7,19 @@
import remarkRehype from 'remark-rehype'; import remarkRehype from 'remark-rehype';
import rehypeKatex from 'rehype-katex'; import rehypeKatex from 'rehype-katex';
import rehypeStringify from 'rehype-stringify'; import rehypeStringify from 'rehype-stringify';
import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils'; import type { Root as HastRoot, RootContent as HastRootContent } from 'hast';
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; import type { Root as MdastRoot } from 'mdast';
import { browser } from '$app/environment'; import { browser } from '$app/environment';
import { onDestroy, tick } from 'svelte';
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
import { rehypeEnhanceLinks } from '$lib/markdown/enhance-links';
import { rehypeEnhanceCodeBlocks } from '$lib/markdown/enhance-code-blocks';
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils';
import '$styles/katex-custom.scss'; import '$styles/katex-custom.scss';
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline'; import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
import githubLightCss from 'highlight.js/styles/github.css?inline'; import githubLightCss from 'highlight.js/styles/github.css?inline';
import { mode } from 'mode-watcher'; import { mode } from 'mode-watcher';
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
import CodePreviewDialog from './CodePreviewDialog.svelte'; import CodePreviewDialog from './CodePreviewDialog.svelte';
interface Props { interface Props {
@ -23,33 +27,24 @@
class?: string; class?: string;
} }
interface MarkdownBlock {
id: string;
html: string;
}
let { content, class: className = '' }: Props = $props(); let { content, class: className = '' }: Props = $props();
let containerRef = $state<HTMLDivElement>(); let containerRef = $state<HTMLDivElement>();
let processedHtml = $state(''); let renderedBlocks = $state<MarkdownBlock[]>([]);
let unstableBlockHtml = $state('');
let previewDialogOpen = $state(false); let previewDialogOpen = $state(false);
let previewCode = $state(''); let previewCode = $state('');
let previewLanguage = $state('text'); let previewLanguage = $state('text');
function loadHighlightTheme(isDark: boolean) { let pendingMarkdown: string | null = null;
if (!browser) return; let isProcessing = false;
const existingThemes = document.querySelectorAll('style[data-highlight-theme]'); const themeStyleId = `highlight-theme-${(window.idxThemeStyle = (window.idxThemeStyle ?? 0) + 1)}`;
existingThemes.forEach((style) => style.remove());
const style = document.createElement('style');
style.setAttribute('data-highlight-theme', 'true');
style.textContent = isDark ? githubDarkCss : githubLightCss;
document.head.appendChild(style);
}
$effect(() => {
const currentMode = mode.current;
const isDark = currentMode === 'dark';
loadHighlightTheme(isDark);
});
let processor = $derived(() => { let processor = $derived(() => {
return remark() return remark()
@ -61,139 +56,64 @@
.use(rehypeKatex) // Render math using KaTeX .use(rehypeKatex) // Render math using KaTeX
.use(rehypeHighlight) // Add syntax highlighting .use(rehypeHighlight) // Add syntax highlighting
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables .use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
.use(rehypeStringify); // Convert to HTML string .use(rehypeEnhanceLinks) // Add target="_blank" to links
.use(rehypeEnhanceCodeBlocks) // Wrap code blocks with header and actions
.use(rehypeStringify, { allowDangerousHtml: true }); // Convert to HTML string
}); });
function enhanceLinks(html: string): string { /**
if (!html.includes('<a')) { * Removes click event listeners from copy and preview buttons.
return html; * Called on component destroy.
*/
function cleanupEventListeners() {
if (!containerRef) return;
const copyButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
const previewButtons = containerRef.querySelectorAll<HTMLButtonElement>('.preview-code-btn');
for (const button of copyButtons) {
button.removeEventListener('click', handleCopyClick);
} }
const tempDiv = document.createElement('div'); for (const button of previewButtons) {
tempDiv.innerHTML = html; button.removeEventListener('click', handlePreviewClick);
// Make all links open in new tabs
const linkElements = tempDiv.querySelectorAll('a[href]');
let mutated = false;
for (const link of linkElements) {
const target = link.getAttribute('target');
const rel = link.getAttribute('rel');
if (target !== '_blank' || rel !== 'noopener noreferrer') {
mutated = true;
}
link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer');
}
return mutated ? tempDiv.innerHTML : html;
}
function enhanceCodeBlocks(html: string): string {
if (!html.includes('<pre')) {
return html;
}
const tempDiv = document.createElement('div');
tempDiv.innerHTML = html;
const preElements = tempDiv.querySelectorAll('pre');
let mutated = false;
for (const [index, pre] of Array.from(preElements).entries()) {
const codeElement = pre.querySelector('code');
if (!codeElement) {
continue;
}
mutated = true;
let language = 'text';
const classList = Array.from(codeElement.classList);
for (const className of classList) {
if (className.startsWith('language-')) {
language = className.replace('language-', '');
break;
}
}
const rawCode = codeElement.textContent || '';
const codeId = `code-${Date.now()}-${index}`;
codeElement.setAttribute('data-code-id', codeId);
codeElement.setAttribute('data-raw-code', rawCode);
const wrapper = document.createElement('div');
wrapper.className = 'code-block-wrapper';
const header = document.createElement('div');
header.className = 'code-block-header';
const languageLabel = document.createElement('span');
languageLabel.className = 'code-language';
languageLabel.textContent = language;
const copyButton = document.createElement('button');
copyButton.className = 'copy-code-btn';
copyButton.setAttribute('data-code-id', codeId);
copyButton.setAttribute('title', 'Copy code');
copyButton.setAttribute('type', 'button');
copyButton.innerHTML = `
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>
`;
const actions = document.createElement('div');
actions.className = 'code-block-actions';
actions.appendChild(copyButton);
if (language.toLowerCase() === 'html') {
const previewButton = document.createElement('button');
previewButton.className = 'preview-code-btn';
previewButton.setAttribute('data-code-id', codeId);
previewButton.setAttribute('title', 'Preview code');
previewButton.setAttribute('type', 'button');
previewButton.innerHTML = `
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye lucide-eye-icon"><path d="M2.062 12.345a1 1 0 0 1 0-.69C3.5 7.73 7.36 5 12 5s8.5 2.73 9.938 6.655a1 1 0 0 1 0 .69C20.5 16.27 16.64 19 12 19s-8.5-2.73-9.938-6.655"/><circle cx="12" cy="12" r="3"/></svg>
`;
actions.appendChild(previewButton);
}
header.appendChild(languageLabel);
header.appendChild(actions);
wrapper.appendChild(header);
const clonedPre = pre.cloneNode(true) as HTMLElement;
wrapper.appendChild(clonedPre);
pre.parentNode?.replaceChild(wrapper, pre);
}
return mutated ? tempDiv.innerHTML : html;
}
async function processMarkdown(text: string): Promise<string> {
try {
let normalized = preprocessLaTeX(text);
const result = await processor().process(normalized);
const html = String(result);
const enhancedLinks = enhanceLinks(html);
return enhanceCodeBlocks(enhancedLinks);
} catch (error) {
console.error('Markdown processing error:', error);
// Fallback to plain text with line breaks
return text.replace(/\n/g, '<br>');
} }
} }
/**
* Removes this component's highlight.js theme style from the document head.
* Called on component destroy to clean up injected styles.
*/
function cleanupHighlightTheme() {
if (!browser) return;
const existingTheme = document.getElementById(themeStyleId);
existingTheme?.remove();
}
/**
* Loads the appropriate highlight.js theme based on dark/light mode.
* Injects a scoped style element into the document head.
* @param isDark - Whether to load the dark theme (true) or light theme (false)
*/
function loadHighlightTheme(isDark: boolean) {
if (!browser) return;
const existingTheme = document.getElementById(themeStyleId);
existingTheme?.remove();
const style = document.createElement('style');
style.id = themeStyleId;
style.textContent = isDark ? githubDarkCss : githubLightCss;
document.head.appendChild(style);
}
/**
* Extracts code information from a button click target within a code block.
* @param target - The clicked button element
* @returns Object with rawCode and language, or null if extraction fails
*/
	function getCodeInfoFromTarget(target: HTMLElement) {
		const wrapper = target.closest('.code-block-wrapper');
@@ -209,12 +129,7 @@
			return null;
		}

-		const rawCode = codeElement.getAttribute('data-raw-code');
-		if (rawCode === null) {
-			console.error('No raw code found');
-			return null;
-		}
+		const rawCode = codeElement.textContent ?? '';

		const languageLabel = wrapper.querySelector<HTMLElement>('.code-language');
		const language = languageLabel?.textContent?.trim() || 'text';
@@ -222,6 +137,28 @@
		return { rawCode, language };
	}
/**
* Generates a unique identifier for a HAST node based on its position.
* Used for stable block identification during incremental rendering.
* @param node - The HAST root content node
* @param indexFallback - Fallback index if position is unavailable
* @returns Unique string identifier for the node
*/
function getHastNodeId(node: HastRootContent, indexFallback: number): string {
const position = node.position;
if (position?.start?.offset != null && position?.end?.offset != null) {
return `hast-${position.start.offset}-${position.end.offset}`;
}
return `${node.type}-${indexFallback}`;
}
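	// Example (illustrative only, not part of the component): during streaming,
	// a top-level paragraph spanning source offsets 0..42 keeps the id
	// "hast-0-42" across re-parses, so its cached HTML block is reused;
	// a node without position data falls back to e.g. "paragraph-3".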
/**
* Handles click events on copy buttons within code blocks.
* Copies the raw code content to the clipboard.
* @param event - The click event from the copy button
*/
	async function handleCopyClick(event: Event) {
		event.preventDefault();
		event.stopPropagation();
@@ -245,6 +182,25 @@
		}
	}
/**
* Handles preview dialog open state changes.
* Clears preview content when dialog is closed.
* @param open - Whether the dialog is being opened or closed
*/
function handlePreviewDialogOpenChange(open: boolean) {
previewDialogOpen = open;
if (!open) {
previewCode = '';
previewLanguage = 'text';
}
}
/**
* Handles click events on preview buttons within HTML code blocks.
* Opens a preview dialog with the rendered HTML content.
* @param event - The click event from the preview button
*/
	function handlePreviewClick(event: Event) {
		event.preventDefault();
		event.stopPropagation();
@@ -266,6 +222,61 @@
		previewDialogOpen = true;
	}
/**
* Processes markdown content into stable and unstable HTML blocks.
* Uses incremental rendering: stable blocks are cached, unstable block is re-rendered.
* @param markdown - The raw markdown string to process
*/
async function processMarkdown(markdown: string) {
if (!markdown) {
renderedBlocks = [];
unstableBlockHtml = '';
return;
}
const normalized = preprocessLaTeX(markdown);
const processorInstance = processor();
const ast = processorInstance.parse(normalized) as MdastRoot;
const processedRoot = (await processorInstance.run(ast)) as HastRoot;
const processedChildren = processedRoot.children ?? [];
const stableCount = Math.max(processedChildren.length - 1, 0);
const nextBlocks: MarkdownBlock[] = [];
for (let index = 0; index < stableCount; index++) {
const hastChild = processedChildren[index];
const id = getHastNodeId(hastChild, index);
const existing = renderedBlocks[index];
if (existing && existing.id === id) {
nextBlocks.push(existing);
continue;
}
const html = stringifyProcessedNode(
processorInstance,
processedRoot,
processedChildren[index]
);
nextBlocks.push({ id, html });
}
let unstableHtml = '';
if (processedChildren.length > stableCount) {
const unstableChild = processedChildren[stableCount];
unstableHtml = stringifyProcessedNode(processorInstance, processedRoot, unstableChild);
}
renderedBlocks = nextBlocks;
await tick(); // Force DOM sync before updating unstable HTML block
unstableBlockHtml = unstableHtml;
}
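	// For reference, the state written above has the (assumed) shape:
	//   interface MarkdownBlock { id: string; html: string }
	//   let renderedBlocks: MarkdownBlock[];  // cached stable blocks
	//   let unstableBlockHtml: string;        // trailing block, re-rendered per update
	// The actual declarations live earlier in this component.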
/**
* Attaches click event listeners to copy and preview buttons in code blocks.
* Uses data-listener-bound attribute to prevent duplicate bindings.
*/
	function setupCodeBlockActions() {
		if (!containerRef) return;
@@ -287,40 +298,97 @@
		}
	}
-	function handlePreviewDialogOpenChange(open: boolean) {
-		previewDialogOpen = open;
-
-		if (!open) {
-			previewCode = '';
-			previewLanguage = 'text';
-		}
-	}
+	/**
+	 * Converts a single HAST node to an enhanced HTML string.
+	 * Applies link and code block enhancements to the output.
+	 * @param processorInstance - The remark/rehype processor instance
+	 * @param processedRoot - The full processed HAST root (for context)
+	 * @param child - The specific HAST child node to stringify
+	 * @returns Enhanced HTML string representation of the node
+	 */
+	function stringifyProcessedNode(
+		processorInstance: ReturnType<typeof processor>,
+		processedRoot: HastRoot,
+		child: unknown
+	) {
+		const root: HastRoot = {
+			...(processedRoot as HastRoot),
+			children: [child as never]
+		};
+
+		return processorInstance.stringify(root);
+	}
+
+	/**
+	 * Queues markdown for processing with coalescing support.
+	 * Only processes the latest markdown when multiple updates arrive quickly.
+	 * @param markdown - The markdown content to render
+	 */
+	async function updateRenderedBlocks(markdown: string) {
+		pendingMarkdown = markdown;
+		if (isProcessing) {
+			return;
+		}
+
+		isProcessing = true;
+		try {
+			while (pendingMarkdown !== null) {
+				const nextMarkdown = pendingMarkdown;
+				pendingMarkdown = null;
+				await processMarkdown(nextMarkdown);
+			}
+		} catch (error) {
+			console.error('Failed to process markdown:', error);
+			renderedBlocks = [];
+			unstableBlockHtml = markdown.replace(/\n/g, '<br>');
+		} finally {
+			isProcessing = false;
+		}
+	}

	$effect(() => {
-		if (content) {
-			processMarkdown(content)
-				.then((result) => {
-					processedHtml = result;
-				})
-				.catch((error) => {
-					console.error('Failed to process markdown:', error);
-					processedHtml = content.replace(/\n/g, '<br>');
-				});
-		} else {
-			processedHtml = '';
-		}
+		const currentMode = mode.current;
+		const isDark = currentMode === 'dark';
+
+		loadHighlightTheme(isDark);
	});

	$effect(() => {
-		if (containerRef && processedHtml) {
+		updateRenderedBlocks(content);
+	});
+
+	$effect(() => {
+		const hasRenderedBlocks = renderedBlocks.length > 0;
+		const hasUnstableBlock = Boolean(unstableBlockHtml);
+
+		if ((hasRenderedBlocks || hasUnstableBlock) && containerRef) {
			setupCodeBlockActions();
		}
	});
+
+	onDestroy(() => {
+		cleanupEventListeners();
+		cleanupHighlightTheme();
+	});
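+	// Illustration (not part of the diff): how updateRenderedBlocks coalesces
+	// a rapid stream of updates (assumed call timing):
+	//   updateRenderedBlocks('a');   // starts processing 'a'
+	//   updateRenderedBlocks('ab');  // queued as pendingMarkdown
+	//   updateRenderedBlocks('abc'); // overwrites the queued value
+	// Only 'a' and then 'abc' are rendered; 'ab' is skipped.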
</script>

<div bind:this={containerRef} class={className}>
-	<!-- eslint-disable-next-line no-at-html-tags -->
-	{@html processedHtml}
+	{#each renderedBlocks as block (block.id)}
+		<div class="markdown-block" data-block-id={block.id}>
+			<!-- eslint-disable-next-line no-at-html-tags -->
+			{@html block.html}
+		</div>
+	{/each}
+
+	{#if unstableBlockHtml}
+		<div class="markdown-block markdown-block--unstable" data-block-id="unstable">
+			<!-- eslint-disable-next-line no-at-html-tags -->
+			{@html unstableBlockHtml}
+		</div>
+	{/if}
</div>

<CodePreviewDialog
@@ -331,6 +399,11 @@
/>

<style>
+	.markdown-block,
+	.markdown-block--unstable {
+		display: contents;
+	}
+
	/* Base typography styles */
	div :global(p:not(:last-child)) {
		margin-bottom: 1rem;

View File

@@ -0,0 +1,4 @@
export enum ChatMessageStatsView {
GENERATION = 'generation',
READING = 'reading'
}

View File

@@ -1,5 +1,7 @@
export { AttachmentType } from './attachment';
+export { ChatMessageStatsView } from './chat';
export {
	FileTypeCategory,
	FileTypeImage,

View File

@@ -0,0 +1,162 @@
/**
* Rehype plugin to enhance code blocks with wrapper, header, and action buttons.
*
* Wraps <pre><code> elements with a container that includes:
* - Language label
* - Copy button
* - Preview button (for HTML code blocks)
*
* This operates directly on the HAST tree for better performance,
* avoiding the need to stringify and re-parse HTML.
*/
import type { Plugin } from 'unified';
import type { Root, Element, ElementContent } from 'hast';
import { visit } from 'unist-util-visit';
declare global {
interface Window {
idxCodeBlock?: number;
}
}
const COPY_ICON_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>`;
const PREVIEW_ICON_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye lucide-eye-icon"><path d="M2.062 12.345a1 1 0 0 1 0-.69C3.5 7.73 7.36 5 12 5s8.5 2.73 9.938 6.655a1 1 0 0 1 0 .69C20.5 16.27 16.64 19 12 19s-8.5-2.73-9.938-6.655"/><circle cx="12" cy="12" r="3"/></svg>`;
/**
 * Creates an SVG element node from a raw SVG string.
* Since we can't parse HTML in HAST directly, we use the raw property.
*/
function createRawHtmlElement(html: string): Element {
return {
type: 'element',
tagName: 'span',
properties: {},
children: [{ type: 'raw', value: html } as unknown as ElementContent]
};
}
function createCopyButton(codeId: string): Element {
return {
type: 'element',
tagName: 'button',
properties: {
className: ['copy-code-btn'],
'data-code-id': codeId,
title: 'Copy code',
type: 'button'
},
children: [createRawHtmlElement(COPY_ICON_SVG)]
};
}
function createPreviewButton(codeId: string): Element {
return {
type: 'element',
tagName: 'button',
properties: {
className: ['preview-code-btn'],
'data-code-id': codeId,
title: 'Preview code',
type: 'button'
},
children: [createRawHtmlElement(PREVIEW_ICON_SVG)]
};
}
function createHeader(language: string, codeId: string): Element {
const actions: Element[] = [createCopyButton(codeId)];
if (language.toLowerCase() === 'html') {
actions.push(createPreviewButton(codeId));
}
return {
type: 'element',
tagName: 'div',
properties: { className: ['code-block-header'] },
children: [
{
type: 'element',
tagName: 'span',
properties: { className: ['code-language'] },
children: [{ type: 'text', value: language }]
},
{
type: 'element',
tagName: 'div',
properties: { className: ['code-block-actions'] },
children: actions
}
]
};
}
function createWrapper(header: Element, preElement: Element): Element {
return {
type: 'element',
tagName: 'div',
properties: { className: ['code-block-wrapper'] },
children: [header, preElement]
};
}
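// For orientation, the wrapper produced above serializes to roughly:
//   <div class="code-block-wrapper">
//     <div class="code-block-header">
//       <span class="code-language">html</span>
//       <div class="code-block-actions">
//         <button class="copy-code-btn" data-code-id="code-1">…</button>
//         <button class="preview-code-btn" data-code-id="code-1">…</button>  (html blocks only)
//       </div>
//     </div>
//     <pre><code data-code-id="code-1">…</code></pre>
//   </div>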
function extractLanguage(codeElement: Element): string {
const className = codeElement.properties?.className;
if (!Array.isArray(className)) return 'text';
for (const cls of className) {
if (typeof cls === 'string' && cls.startsWith('language-')) {
return cls.replace('language-', '');
}
}
return 'text';
}
/**
* Generates a unique code block ID using a global counter.
*/
function generateCodeId(): string {
if (typeof window !== 'undefined') {
return `code-${(window.idxCodeBlock = (window.idxCodeBlock ?? 0) + 1)}`;
}
// Fallback for SSR - use timestamp + random
return `code-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`;
}
/**
* Rehype plugin to enhance code blocks with wrapper, header, and action buttons.
* This plugin wraps <pre><code> elements with a container that includes:
* - Language label
* - Copy button
* - Preview button (for HTML code blocks)
*/
export const rehypeEnhanceCodeBlocks: Plugin<[], Root> = () => {
return (tree: Root) => {
visit(tree, 'element', (node: Element, index, parent) => {
if (node.tagName !== 'pre' || !parent || index === undefined) return;
const codeElement = node.children.find(
(child): child is Element => child.type === 'element' && child.tagName === 'code'
);
if (!codeElement) return;
const language = extractLanguage(codeElement);
const codeId = generateCodeId();
codeElement.properties = {
...codeElement.properties,
'data-code-id': codeId
};
const header = createHeader(language, codeId);
const wrapper = createWrapper(header, node);
// Replace pre with wrapper in parent
(parent.children as ElementContent[])[index] = wrapper;
});
};
};
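A minimal sketch of how this plugin could be wired into a unified pipeline. The chain below is an assumption for illustration (the component's real processor lives elsewhere); `allowDangerousHtml` on rehype-stringify is what lets the plugin's raw SVG icon nodes serialize verbatim:

import { unified } from 'unified';
import remarkParse from 'remark-parse';
import remarkRehype from 'remark-rehype';
import rehypeStringify from 'rehype-stringify';
import { rehypeEnhanceCodeBlocks } from './rehypeEnhanceCodeBlocks';

const html = String(
	await unified()
		.use(remarkParse)
		.use(remarkRehype, { allowDangerousHtml: true })
		.use(rehypeEnhanceCodeBlocks)
		.use(rehypeStringify, { allowDangerousHtml: true })
		.process('```html\n<b>hi</b>\n```')
);
// html contains a .code-block-wrapper with a copy button and, because the
// fence language is "html", a preview button next to it.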

View File

@@ -0,0 +1,33 @@
/**
* Rehype plugin to enhance links with security attributes.
*
* Adds target="_blank" and rel="noopener noreferrer" to all anchor elements,
* ensuring external links open in new tabs safely.
*/
import type { Plugin } from 'unified';
import type { Root, Element } from 'hast';
import { visit } from 'unist-util-visit';
/**
* Rehype plugin that adds security attributes to all links.
* This plugin ensures external links open in new tabs safely by adding:
* - target="_blank"
* - rel="noopener noreferrer"
*/
export const rehypeEnhanceLinks: Plugin<[], Root> = () => {
return (tree: Root) => {
visit(tree, 'element', (node: Element) => {
if (node.tagName !== 'a') return;
const props = node.properties ?? {};
// Only modify if href exists
if (!props.href) return;
props.target = '_blank';
props.rel = 'noopener noreferrer';
node.properties = props;
});
};
};
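A hedged, self-contained way to see the plugin's effect in isolation (the import path is assumed):

import { unified } from 'unified';
import rehypeParse from 'rehype-parse';
import rehypeStringify from 'rehype-stringify';
import { rehypeEnhanceLinks } from './rehypeEnhanceLinks';

const out = String(
	await unified()
		.use(rehypeParse, { fragment: true })
		.use(rehypeEnhanceLinks)
		.use(rehypeStringify)
		.process('<a href="https://example.com">docs</a>')
);
// -> <a href="https://example.com" target="_blank" rel="noopener noreferrer">docs</a>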

View File

@@ -171,6 +171,7 @@ class ChatStore {
	updateProcessingStateFromTimings(
		timingData: {
			prompt_n: number;
+			prompt_ms?: number;
			predicted_n: number;
			predicted_per_second: number;
			cache_n: number;
@@ -212,6 +213,7 @@
		if (message.role === 'assistant' && message.timings) {
			const restoredState = this.parseTimingData({
				prompt_n: message.timings.prompt_n || 0,
+				prompt_ms: message.timings.prompt_ms,
				predicted_n: message.timings.predicted_n || 0,
				predicted_per_second:
					message.timings.predicted_n && message.timings.predicted_ms
@@ -282,6 +284,7 @@
	private parseTimingData(timingData: Record<string, unknown>): ApiProcessingState | null {
		const promptTokens = (timingData.prompt_n as number) || 0;
+		const promptMs = (timingData.prompt_ms as number) || undefined;
		const predictedTokens = (timingData.predicted_n as number) || 0;
		const tokensPerSecond = (timingData.predicted_per_second as number) || 0;
		const cacheTokens = (timingData.cache_n as number) || 0;
@@ -320,6 +323,7 @@
			speculative: false,
			progressPercent,
			promptTokens,
+			promptMs,
			cacheTokens
		};
	}
@@ -536,6 +540,7 @@
		this.updateProcessingStateFromTimings(
			{
				prompt_n: timings?.prompt_n || 0,
+				prompt_ms: timings?.prompt_ms,
				predicted_n: timings?.predicted_n || 0,
				predicted_per_second: tokensPerSecond,
				cache_n: timings?.cache_n || 0,
@@ -768,10 +773,11 @@
			content: streamingState.response
		};
		if (lastMessage.thinking?.trim()) updateData.thinking = lastMessage.thinking;
-		const lastKnownState = this.getCurrentProcessingStateSync();
+		const lastKnownState = this.getProcessingState(conversationId);
		if (lastKnownState) {
			updateData.timings = {
				prompt_n: lastKnownState.promptTokens || 0,
+				prompt_ms: lastKnownState.promptMs,
				predicted_n: lastKnownState.tokensDecoded || 0,
				cache_n: lastKnownState.cacheTokens || 0,
				predicted_ms:
@@ -1253,6 +1259,7 @@
		this.updateProcessingStateFromTimings(
			{
				prompt_n: timings?.prompt_n || 0,
+				prompt_ms: timings?.prompt_ms,
				predicted_n: timings?.predicted_n || 0,
				predicted_per_second: tokensPerSecond,
				cache_n: timings?.cache_n || 0,

View File

@@ -345,6 +345,7 @@ export interface ApiProcessingState {
	// Progress information from prompt_progress
	progressPercent?: number;
	promptTokens?: number;
+	promptMs?: number;
	cacheTokens?: number;
}
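With prompt_ms now persisted alongside prompt_n, the UI can derive prompt processing speed even after a conversation is reloaded. A hedged sketch; the helper below is hypothetical, though the field names match the store:

// Hypothetical helper, not part of the store: prompt speed from saved timings.
function promptTokensPerSecond(timings: { prompt_n?: number; prompt_ms?: number }): number | undefined {
	if (!timings.prompt_n || !timings.prompt_ms) return undefined;
	return (timings.prompt_n / timings.prompt_ms) * 1000; // ms -> s
}

promptTokensPerSecond({ prompt_n: 512, prompt_ms: 256 }); // 2000 tokens/s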