Merge remote-tracking branch 'upstream/master' into gpu-sampling

This commit is contained in:
Daniel Bevenius 2025-11-25 08:20:50 +01:00
commit 53dca56d9b
6 changed files with 213 additions and 5 deletions

View File

@ -2501,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
if (hsv >= 192) { if (hsv >= 192) {
return 2; return 2;
} else if ((hsv | hsk) & 8) {
return 4;
} else { } else {
return 8; return 8;
} }
@ -2535,9 +2537,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if ((hsv | hsk) & 8) { if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
return {get_fa_scalar_num_large_rows(hsv), 64}; return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
} else { } else {
return {get_fa_scalar_num_large_rows(hsv), 32}; return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
} }
} }
} }
@ -7740,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// Needs to be kept up to date on shader changes // Needs to be kept up to date on shader changes
GGML_UNUSED(hsv); GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size; const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = get_fa_scalar_num_large_rows(hsv); const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpsh = wg_size * sizeof(float);
@ -7871,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR: case FA_SCALAR:
case FA_COOPMAT1: case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both // We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = get_fa_scalar_num_large_rows(HSV); max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
break; break;
case FA_COOPMAT2: case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2); max_gqa = get_fa_num_small_rows(FA_COOPMAT2);

View File

@ -7859,6 +7859,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
} }
} }
// Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
for (int kv : { 4096, 8192, 16384, }) { for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) { for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) { for (int nr : { 1, 4, }) {

Binary file not shown.

View File

@ -8,6 +8,7 @@
import rehypeKatex from 'rehype-katex'; import rehypeKatex from 'rehype-katex';
import rehypeStringify from 'rehype-stringify'; import rehypeStringify from 'rehype-stringify';
import { copyCodeToClipboard } from '$lib/utils/copy'; import { copyCodeToClipboard } from '$lib/utils/copy';
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { preprocessLaTeX } from '$lib/utils/latex-protection';
import { browser } from '$app/environment'; import { browser } from '$app/environment';
import '$styles/katex-custom.scss'; import '$styles/katex-custom.scss';
@ -60,6 +61,7 @@
.use(remarkRehype) // Convert Markdown AST to rehype .use(remarkRehype) // Convert Markdown AST to rehype
.use(rehypeKatex) // Render math using KaTeX .use(rehypeKatex) // Render math using KaTeX
.use(rehypeHighlight) // Add syntax highlighting .use(rehypeHighlight) // Add syntax highlighting
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
.use(rehypeStringify); // Convert to HTML string .use(rehypeStringify); // Convert to HTML string
}); });

View File

@ -0,0 +1,20 @@
/**
* Matches <br>, <br/>, <br /> tags (case-insensitive).
* Used to detect line breaks in table cell text content.
*/
export const BR_PATTERN = /<br\s*\/?\s*>/gi;
/**
* Matches a complete <ul>...</ul> block.
* Captures the inner content (group 1) for further <li> extraction.
* Case-insensitive, allows multiline content.
*/
export const LIST_PATTERN = /^<ul>([\s\S]*)<\/ul>$/i;
/**
* Matches individual <li>...</li> elements within a list.
* Captures the inner content (group 1) of each list item.
* Non-greedy to handle multiple consecutive items.
* Case-insensitive, allows multiline content.
*/
export const LI_PATTERN = /<li>([\s\S]*?)<\/li>/gi;

View File

@ -0,0 +1,181 @@
/**
* Rehype plugin to restore limited HTML elements inside Markdown table cells.
*
* ## Problem
* The remark/rehype pipeline neutralizes inline HTML as literal text
* (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display
* as-is instead of being rendered. This causes <br> and <ul> markup in
* table cells to show as plain text.
*
* ## Solution
* This plugin traverses the HAST post-conversion, parses whitelisted HTML
* patterns from text nodes, and replaces them with actual HAST element nodes
* that will be rendered as real HTML.
*
* ## Supported HTML
* - `<br>` / `<br/>` / `<br />` - Line breaks (inline)
* - `<ul><li>...</li></ul>` - Unordered lists (block)
*
* ## Key Implementation Details
*
* ### 1. Sibling Combination (Critical)
* The Markdown pipeline may fragment content across multiple text nodes and `<br>`
* elements. For example, `<ul><li>a</li></ul>` might arrive as:
* - Text: `"<ul>"`
* - Element: `<br>`
* - Text: `"<li>a</li></ul>"`
*
* We must combine consecutive text nodes and `<br>` elements into a single string
* before attempting to parse list markup. Without this, list detection fails.
*
* ### 2. visitParents for Deep Traversal
* Table cell content may be wrapped in intermediate elements (e.g., `<p>` tags).
* Using `visitParents` instead of direct child iteration ensures we find text
* nodes at any depth within the cell.
*
* ### 3. Reference Comparison for No-Op Detection
* When checking if `<br>` expansion changed anything, we compare:
* `expanded.length !== 1 || expanded[0] !== textNode`
*
* This catches both cases:
* - Multiple nodes created (text was split)
* - Single NEW node created (original had only `<br>`, now it's an element)
*
* A simple `length > 1` check would miss the single `<br>` case.
*
* ### 4. Strict List Validation
* `parseList()` rejects malformed markup by checking for garbage text between
* `<li>` elements. This prevents creating broken DOM from partial matches like
* `<ul>garbage<li>a</li></ul>`.
*
* ### 5. Newline Substitution for `<br>` in Combined String
* When combining siblings, existing `<br>` elements become `\n` in the combined
* string. This allows list content to span visual lines while still being parsed
* as a single unit.
*
* @example
* // Input Markdown:
* // | Feature | Notes |
* // |---------|-------|
* // | Multi-line | First<br>Second |
* // | List | <ul><li>A</li><li>B</li></ul> |
* //
* // Without this plugin: <br> and <ul> render as literal text
* // With this plugin: <br> becomes line break, <ul> becomes actual list
*/
import type { Plugin } from 'unified';
import type { Element, ElementContent, Root, Text } from 'hast';
import { visit } from 'unist-util-visit';
import { visitParents } from 'unist-util-visit-parents';
import { BR_PATTERN, LIST_PATTERN, LI_PATTERN } from '$lib/constants/table-html-restorer';
/**
* Expands text containing `<br>` tags into an array of text nodes and br elements.
*/
function expandBrTags(value: string): ElementContent[] {
const matches = [...value.matchAll(BR_PATTERN)];
if (!matches.length) return [{ type: 'text', value } as Text];
const result: ElementContent[] = [];
let cursor = 0;
for (const m of matches) {
if (m.index! > cursor) {
result.push({ type: 'text', value: value.slice(cursor, m.index) } as Text);
}
result.push({ type: 'element', tagName: 'br', properties: {}, children: [] } as Element);
cursor = m.index! + m[0].length;
}
if (cursor < value.length) {
result.push({ type: 'text', value: value.slice(cursor) } as Text);
}
return result;
}
/**
* Parses a `<ul><li>...</li></ul>` string into a HAST element.
* Returns null if the markup is malformed or contains unexpected content.
*/
function parseList(value: string): Element | null {
const match = value.trim().match(LIST_PATTERN);
if (!match) return null;
const body = match[1];
const items: ElementContent[] = [];
let cursor = 0;
for (const liMatch of body.matchAll(LI_PATTERN)) {
// Reject if there's non-whitespace between list items
if (body.slice(cursor, liMatch.index!).trim()) return null;
items.push({
type: 'element',
tagName: 'li',
properties: {},
children: expandBrTags(liMatch[1] ?? '')
} as Element);
cursor = liMatch.index! + liMatch[0].length;
}
// Reject if no items found or trailing garbage exists
if (!items.length || body.slice(cursor).trim()) return null;
return { type: 'element', tagName: 'ul', properties: {}, children: items } as Element;
}
/**
* Processes a single table cell, restoring HTML elements from text content.
*/
function processCell(cell: Element) {
visitParents(cell, 'text', (textNode: Text, ancestors) => {
const parent = ancestors[ancestors.length - 1];
if (!parent || parent.type !== 'element') return;
const parentEl = parent as Element;
const siblings = parentEl.children as ElementContent[];
const startIndex = siblings.indexOf(textNode as ElementContent);
if (startIndex === -1) return;
// Combine consecutive text nodes and <br> elements into one string
let combined = '';
let endIndex = startIndex;
for (let i = startIndex; i < siblings.length; i++) {
const sib = siblings[i];
if (sib.type === 'text') {
combined += (sib as Text).value;
endIndex = i;
} else if (sib.type === 'element' && (sib as Element).tagName === 'br') {
combined += '\n';
endIndex = i;
} else {
break;
}
}
// Try parsing as list first (replaces entire combined range)
const list = parseList(combined);
if (list) {
siblings.splice(startIndex, endIndex - startIndex + 1, list);
return;
}
// Otherwise, just expand <br> tags in this text node
const expanded = expandBrTags(textNode.value);
if (expanded.length !== 1 || expanded[0] !== textNode) {
siblings.splice(startIndex, 1, ...expanded);
}
});
}
export const rehypeRestoreTableHtml: Plugin<[], Root> = () => (tree) => {
visit(tree, 'element', (node: Element) => {
if (node.tagName === 'td' || node.tagName === 'th') {
processCell(node);
}
});
};