fix: use tensor stride for fused QKV support in vaetki
This commit is contained in:
parent
c947c74a4c
commit
8bbeab0616
|
|
@ -7969,20 +7969,6 @@ class VaetkiVisionModel(MmprojModel):
|
||||||
if name.startswith("model.visual."):
|
if name.startswith("model.visual."):
|
||||||
name = name.replace("model.visual.", "visual.")
|
name = name.replace("model.visual.", "visual.")
|
||||||
|
|
||||||
# Split fused QKV tensors (build_vit fused QKV doesn't work for VAETKI)
|
|
||||||
if ".qkv." in name:
|
|
||||||
if data_torch.ndim == 2:
|
|
||||||
c3, _ = data_torch.shape
|
|
||||||
else:
|
|
||||||
c3 = data_torch.shape[0]
|
|
||||||
assert c3 % 3 == 0
|
|
||||||
c = c3 // 3
|
|
||||||
return [
|
|
||||||
(self.map_tensor_name(name.replace("qkv", "q")), data_torch[:c]),
|
|
||||||
(self.map_tensor_name(name.replace("qkv", "k")), data_torch[c:c * 2]),
|
|
||||||
(self.map_tensor_name(name.replace("qkv", "v")), data_torch[c * 2:]),
|
|
||||||
]
|
|
||||||
|
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,13 +33,14 @@ ggml_cgraph * clip_graph_vaetki::build() {
|
||||||
|
|
||||||
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) -> ggml_tensor * {
|
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) -> ggml_tensor * {
|
||||||
// split CLS and patch tokens
|
// split CLS and patch tokens
|
||||||
|
// use cur->nb[2] (the per-token stride in bytes) so the views work for both fused QKV (stride spans 3*n_embd elements) and separate Q/K/V (stride spans n_embd elements)
|
||||||
ggml_tensor * cur_cls = ggml_view_3d(ctx0, cur, d_head, n_head, 1,
|
ggml_tensor * cur_cls = ggml_view_3d(ctx0, cur, d_head, n_head, 1,
|
||||||
ggml_row_size(cur->type, d_head),
|
ggml_row_size(cur->type, d_head),
|
||||||
ggml_row_size(cur->type, d_head * n_head), 0);
|
cur->nb[2], 0);
|
||||||
ggml_tensor * cur_patch = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos_patches,
|
ggml_tensor * cur_patch = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos_patches,
|
||||||
ggml_row_size(cur->type, d_head),
|
ggml_row_size(cur->type, d_head),
|
||||||
ggml_row_size(cur->type, d_head * n_head),
|
cur->nb[2],
|
||||||
ggml_row_size(cur->type, d_head * n_head));
|
cur->nb[2]);
|
||||||
|
|
||||||
// apply RoPE to CLS token using class_pos_emb
|
// apply RoPE to CLS token using class_pos_emb
|
||||||
if (cls_cos && cls_sin) {
|
if (cls_cos && cls_sin) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue