clean up
parent b26b507c4e
commit 386ba479a2
@@ -1579,15 +1579,7 @@ class MmprojModel(ModelBase):

         # TODO @ngxson : this is a hack to support both vision and audio encoders
         have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
-        self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
-        # FIXME: DeepseekOCRVisionModel specific hack
-        if self.block_count is None:
-            if isinstance(self, DeepseekOCRVisionModel):
-                clip_block_count = self.hparams['layers']
-                if clip_block_count is not None:
-                    self.block_count = clip_block_count
-        if self.block_count is None:
-            raise KeyError(f"could not find block count using any of: {self.n_block_keys}")
+        self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys)
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)

         # load preprocessor config
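
Note: with the optional-lookup flag dropped, find_hparam presumably raises on a missing key itself, which is why the manual fallback and KeyError above were removed. A rough standalone sketch of that lookup-with-error pattern (the candidate key names here are illustrative, not the project's actual n_block_keys):

    # Hedged sketch (not the project's code): resolve a block count from a list of
    # candidate hyperparameter keys, raising when none of them is present.
    def find_block_count(hparams: dict, candidate_keys=("num_hidden_layers", "depth", "layers")):
        for key in candidate_keys:
            if key in hparams:
                return hparams[key]
        raise KeyError(f"could not find block count using any of: {candidate_keys}")
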
@@ -6003,16 +5995,6 @@ class Gemma3VisionModel(MmprojModel):

 @ModelBase.register("DeepseekOCRForCausalLM")
 class DeepseekOCRVisionModel(MmprojModel):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        proc_fname = self.dir_model / "processor_config.json"
-
-        if proc_fname.is_file():
-            with open(proc_fname, "r") as f:
-                self.preprocessor_config = json.load(f)
-
-
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -6071,27 +6053,6 @@ class DeepseekOCRVisionModel(MmprojModel):
         if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name:
             return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]

-        if name.startswith("model.vision_model.transformer.layers."):
-            # process visual tensors
-            # split QKV tensors if needed
-            if ".qkv_proj." in name:
-                if data_torch.ndim == 2: # weight
-                    c3, _ = data_torch.shape
-                else: # bias
-                    c3 = data_torch.shape[0]
-                assert c3 % 3 == 0
-                c = c3 // 3
-                wq = data_torch[:c]
-                wk = data_torch[c: c * 2]
-                wv = data_torch[c * 2:]
-                return [
-                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
-                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
-                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
-                ]
-            else:
-                return [(self.map_tensor_name(name), data_torch)]
-
         return [(self.map_tensor_name(name), data_torch)]

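
Note: the deleted branch above split a fused QKV projection into equal thirds along its first dimension. A self-contained numpy sketch of that split (illustrative only; the real conversion operates on torch tensors):

    import numpy as np

    def split_qkv(qkv: np.ndarray):
        # the fused tensor stacks Q, K and V along dim 0, so it is cut into three equal parts
        c3 = qkv.shape[0]
        assert c3 % 3 == 0, "fused QKV dim must be divisible by 3"
        c = c3 // 3
        return qkv[:c], qkv[c:2 * c], qkv[2 * c:]

    # works for both 2D weights [3*c, n_embd] and 1D biases [3*c]
    wq, wk, wv = split_qkv(np.zeros((3 * 768, 768)))
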
@@ -7335,10 +7296,9 @@ class DeepseekV2Model(TextModel):

         super().set_gguf_parameters()
         hparams = self.hparams
-        kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512
+        kv_lora_rank = hparams["kv_lora_rank"] if hparams["kv_lora_rank"] is not None else 512
         routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
         norm_topk_prob = hparams.get("norm_topk_prob", False)
-        scoring_func = hparams.get("scoring_func", "softmax")

         self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
@@ -7361,12 +7321,6 @@ class DeepseekV2Model(TextModel):
         self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
         self.gguf_writer.add_expert_weights_norm(norm_topk_prob)

-        if scoring_func == "sigmoid":
-            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
-        elif scoring_func == "softmax":
-            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
-        else:
-            raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

         rope_scaling = self.hparams.get("rope_scaling") or {}
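
Note: the deleted lines mapped the config's scoring_func onto a GGUF expert-gating enum. A minimal sketch of that mapping, kept only to document the removed behaviour (the enum names are the ones visible in the diff):

    def expert_gating_func(scoring_func: str) -> str:
        # "sigmoid" and "softmax" were the only accepted values; anything else raised
        mapping = {"sigmoid": "SIGMOID", "softmax": "SOFTMAX"}
        if scoring_func not in mapping:
            raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
        return mapping[scoring_func]
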
@@ -74,19 +74,19 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
         }
     }
     for (int64_t i3 = 0; i3 < ne[3]; i3++) {
         LOG(" [\n");
         for (int64_t i2 = 0; i2 < ne[2]; i2++) {
             if (i2 == n && ne[2] > 2*n) {
                 LOG(" ..., \n");
                 i2 = ne[2] - n;
             }
             LOG(" [\n");
             for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                 if (i1 == n && ne[1] > 2*n) {
                     LOG(" ..., \n");
                     i1 = ne[1] - n;
                 }
                 LOG(" [");
                 for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                     if (i0 == n && ne[0] > 2*n) {
                         LOG("..., ");
@@ -98,10 +98,10 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
             }
             LOG("],\n");
         }
         LOG(" ],\n");
     }
     LOG(" ]\n");
     LOG(" sum = %f\n", sum);
 }

 // TODO: make this abort configurable/optional?
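
Note: the printing loops above elide the middle of any dimension longer than 2*n entries: the first n values are printed, then an ellipsis, then the last n. A small Python sketch of that rule:

    def truncated_indices(length: int, n: int = 3):
        # mirror the "print first n, '...', then last n" rule applied per dimension
        if length > 2 * n:
            return list(range(n)) + ["..."] + list(range(length - n, length))
        return list(range(length))

    print(truncated_indices(10))  # [0, 1, 2, '...', 7, 8, 9]
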
@@ -136,7 +136,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
             snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
         }

-        LOG("%s: %16s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
+        LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
             t->name, ggml_type_name(t->type), ggml_op_desc(t),
             src0->name, ggml_ne_string(src0).c_str(),
             src1 ? src1_str : "",
@@ -1127,12 +1127,12 @@ class GGUFWriter:
     def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
         self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)

     def add_vision_sam_layers_count(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value)

     def add_vision_sam_embedding_length(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value)

     # audio models

     def add_audio_projection_dim(self, value: int) -> None:
@@ -1,9 +1,6 @@
 from __future__ import annotations

 from typing import Sequence

-from numpy.f2py.auxfuncs import throw_error
-
 from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES

@@ -1242,11 +1239,11 @@ class TensorNameMap:
            "visual.pos_embed", # qwen3vl
            "model.vision.patch_embedding.position_embedding", # cogvlm
        ),

        MODEL_TENSOR.V_ENC_EMBD_IMGNL: (
            "model.image_newline", # Deepseek-OCR
        ),

        MODEL_TENSOR.V_ENC_EMBD_VSEP: (
            "model.view_seperator", # Deepseek-OCR
        ),
@@ -561,9 +561,9 @@ struct clip_graph {
         hparams(model.hparams),
         img(img),
         patch_size(hparams.patch_size),
-        n_patches_x(img.nx / patch_size), // sam 1024 / 16 = 64
-        n_patches_y(img.ny / patch_size), // sam 1024 / 16 = 64
-        n_patches(n_patches_x * n_patches_y), // sam 64 * 64 = 4096
+        n_patches_x(img.nx / patch_size),
+        n_patches_y(img.ny / patch_size),
+        n_patches(n_patches_x * n_patches_y),
         n_embd(hparams.n_embd),
         n_head(hparams.n_head),
         d_head(n_embd / n_head),
@@ -664,13 +664,13 @@ struct clip_graph {
         ggml_tensor * inp_raw = build_inp_raw();
         ggml_tensor * sam_out = build_sam(inp_raw);
         ggml_tensor * clip_out = build_dsocr_clip(sam_out);

         int clip_n_patches = sam_out->ne[0] * sam_out->ne[1];

         sam_out = ggml_cont(ctx0, ggml_permute(ctx0, sam_out, 1, 2, 0, 3));
         sam_out = ggml_reshape_2d(ctx0, sam_out, sam_out->ne[0], clip_n_patches);
         clip_out = ggml_view_2d(ctx0, clip_out, n_embd, clip_n_patches, clip_out->nb[1], clip_out->nb[1]);

         ggml_tensor * cur;
         cur = ggml_concat(ctx0, clip_out, sam_out, 0);
         cur = ggml_reshape_2d(ctx0, cur, 2*n_embd,clip_n_patches);
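
Note: the fusion above concatenates the SAM features with the CLIP features patch by patch (the ggml_concat along ne[0], the feature axis), giving 2*n_embd channels per patch. A loose numpy analogue, treating ggml's ne[0] as the last numpy axis; the concrete sizes are assumptions for illustration:

    import numpy as np

    n_embd, n_patches = 768, 4096
    sam_out  = np.zeros((n_patches, n_embd))   # SAM features, one row per patch after the reshape
    clip_out = np.zeros((n_patches, n_embd))   # CLIP features for the same patches
    fused = np.concatenate([clip_out, sam_out], axis=-1)   # (n_patches, 2*n_embd)
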
@@ -1302,7 +1302,7 @@ struct clip_graph {
             norm_t,
             hparams.ffn_op,
             model.position_embeddings,
-            nullptr); // shape [1024, 16, 16]
+            nullptr);

         // remove CLS token
         cur = ggml_view_2d(ctx0, cur,
@@ -2260,7 +2260,6 @@ private:
         const int64_t C = rel_pos->ne[0]; // channels
         const int64_t L = rel_pos->ne[1]; // length

-        //GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);

         const auto max_rel_dist = 2*std::max(q_size, k_size) - 1;
         ggml_tensor * rel_pos_resized = rel_pos;
@@ -2399,18 +2398,15 @@ private:
     // build the input after conv2d (inp_raw --> patches)
     // returns tensor with shape [n_embd, n_patches]
     ggml_tensor * build_inp() {
-        // Image to Patch Embedding.
-        ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3]
-        // sam patch_embeddings_0 shape = [768, 3, 16, 16]
-        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); // sam shape = [64, 64, 768]
-        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); // sam shape = [4096, 768]
-        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // sam shape = [768, 4096]
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));

         if (model.patch_bias) {
-            // sam patch_bias shape = [768]
             inp = ggml_add(ctx0, inp, model.patch_bias);
             cb(inp, "patch_bias", -1);
         }
-        return inp; // shape = [n_embd, n_patches] same as [768, 4096]
+        return inp;
     }

     ggml_tensor * build_inp_raw(int channels = 3) {
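
Note: the removed comments documented the SAM shape flow through build_inp; restated here as a numpy sketch, using the values those comments gave (1024x1024 input, patch_size 16, n_embd 768):

    import numpy as np

    img  = np.zeros((1024, 1024, 3))            # raw input image
    grid = (1024 // 16, 1024 // 16)             # 64 x 64 patch grid after the strided conv2d
    inp  = np.zeros((grid[0] * grid[1], 768))   # [4096, 768] once the grid is flattened
    inp  = inp.T                                # [768, 4096] = [n_embd, n_patches]
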
@@ -2707,11 +2703,11 @@ private:
         const int d_heads = n_embd / n_heads;

         ggml_tensor * inpL;

         inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
         inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd));
         inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));

         ggml_tensor * cur;
         const auto tgt_size = inpL->ne[1];
         const auto str_size = model.pos_embed->ne[1];
@@ -2756,7 +2752,7 @@ private:
             // self-attention
             {
                 const int B = cur->ne[3];

                 cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
                 cur = ggml_add(ctx0, cur, layer.qkv_b);
                 cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
@@ -2836,7 +2832,7 @@ private:
         cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
         cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1);
         cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));

         cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
         cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
         cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1);
@@ -2866,7 +2862,7 @@ private:
         if (tgt_size != src_size) {
             ggml_tensor * old_pos_embd;
             ggml_tensor * cls_tok;

             old_pos_embd = ggml_view_2d(
                 ctx0, new_pos_embd,
                 new_pos_embd->ne[0], src_size * src_size,
@@ -2895,7 +2891,7 @@ private:
         ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
         ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);

         ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK,
             learned_pos_embd, nullptr); // shape [1024, 16, 16]

         ggml_build_forward_expand(gf, cur);
@@ -5174,11 +5170,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
    const int orig_h = original_size.height;
    const int orig_area = orig_h * orig_w;
    std::array<uint8_t, 3u> color;

    for (int i = 0; i < 3; i++) {
        color[i] = (int)(255 * params.image_mean[i]);
    }

    int mode_i = 0;
    int min_diff = orig_area;

@@ -5193,7 +5189,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
    if (mode_i < 2) {
        /* Native Resolution (Tiny/Small) */
        const int image_size = native_resolutions[mode_i];

        // Just resize the image to image_size × image_size
        clip_image_u8_ptr resized_img(clip_image_u8_init());
        img_tool::resize(*img, *resized_img,
@@ -5210,7 +5206,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
    else if (mode_i < 4) {
        /* Native Resolution (Base/Large) */
        const int image_size = native_resolutions[mode_i];

        // Resize maintaining aspect ratio, then pad to square
        float scale = std::min(
            static_cast<float>(image_size) / orig_w,
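
Note: the Base/Large branch resizes with the aspect ratio preserved and then pads the result out to a square canvas. A minimal sketch of that geometry (the helper name is hypothetical, not part of clip.cpp):

    def fit_and_pad(orig_w: int, orig_h: int, image_size: int):
        # scale so the longer side fits exactly into the square canvas
        scale = min(image_size / orig_w, image_size / orig_h)
        new_w, new_h = round(orig_w * scale), round(orig_h * scale)
        # padding needed to reach an image_size x image_size canvas
        return (new_w, new_h), (image_size - new_w, image_size - new_h)
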
@@ -5267,7 +5263,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
    else {
        GGML_ABORT("DeepSeek-OCR hasn't supported Gundam/Gundam-Master yet");
        /* Dynamic Resolution (Gundam/Gundam-Master) */

        // configurable, or read from params
        const int min_num = 2;
        const int max_num = 9;
@@ -5276,10 +5272,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
        // original image size
        const int orig_w = original_size.width;
        const int orig_h = original_size.height;

        // create overview image (thumbnail)
        clip_image_u8_ptr overview_img(clip_image_u8_init());
        img_tool::resize(*img, *overview_img, { image_size, image_size },
                         img_tool::RESIZE_ALGO_BICUBIC_PILLOW, true, color);
        clip_image_f32_ptr overview_f32(clip_image_f32_init());
        normalize_image_u8_to_f32(*overview_img, *overview_f32, params.image_mean, params.image_std);
@@ -5287,7 +5283,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

        // build candidate grids (cols, rows)
        auto target_ratios = ds_build_target_ratios(min_num, max_num);

        // pick the grid that best matches the original aspect ratio
        const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
        auto best = ds_find_closest_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
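
Note: ds_build_target_ratios and ds_find_closest_ratio follow the usual dynamic-tiling pattern: enumerate (cols, rows) grids whose tile count lies in [min_num, max_num], then pick the grid closest to the image's aspect ratio. A rough Python sketch of that idea (a simplification, not the actual clip.cpp implementation, which also weighs the original dimensions and tile size):

    def build_target_ratios(min_num: int, max_num: int):
        grids = {(c, r) for c in range(1, max_num + 1)
                        for r in range(1, max_num + 1)
                        if min_num <= c * r <= max_num}
        return sorted(grids, key=lambda g: g[0] * g[1])

    def find_closest_ratio(aspect_ratio: float, grids):
        # choose the (cols, rows) grid whose cols/rows ratio best matches the image
        return min(grids, key=lambda g: abs(aspect_ratio - g[0] / g[1]))
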
@@ -5296,7 +5292,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

        // resize to refined size (no padding, direct resize)
        clip_image_u8_ptr refined_img(clip_image_u8_init());
        img_tool::resize(*img, *refined_img, { image_size * grid_cols, image_size * grid_rows },
                         img_tool::RESIZE_ALGO_BICUBIC_PILLOW, false);

        // crop slices from the refined image