Merge pull request #7 from bluebread/sf/deepseek-ocr

First DeepSeek-OCR working implementation

commit 6b0e7cd136
@@ -1824,6 +1824,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.image_max_tokens = value;
        }
    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS"));
    add_opt(common_arg(
        {"--dsocr-mode"}, "MODE",
        "DeepSeek-OCR resolution mode, one of:\n"
        "- auto (default): automatically select resolution\n"
        "- tiny, small, base, large: native resolution\n"
        "- gundam, gundam-master: dynamic resolution",
        [](common_params & params, const std::string & value) {
            if (value == "auto" || value == "tiny" || value == "small" || value == "base" ||
                value == "large" || value == "gundam" || value == "gundam-master") {
                params.dsocr_mode = value;
            } else {
                throw std::invalid_argument("invalid value");
            }
        }
    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_DSOCR_MODE"));
    if (llama_supports_rpc()) {
        add_opt(common_arg(
            {"--rpc"}, "SERVERS",
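For reference, each native mode implies a fixed square input resolution (the values come from the native_resolutions table in the clip.cpp preprocessing hunk further down), while the gundam modes tile the image dynamically. A minimal sketch of that mapping, assuming only what the help text and that table state:

    // Sketch only: square resolution implied by each native mode, per the
    // native_resolutions table later in this diff; -1 means "not fixed"
    // (auto picks one of these, gundam/gundam-master tile dynamically).
    static int dsocr_native_resolution(const std::string & mode) {
        if (mode == "tiny")  return  512;
        if (mode == "small") return  640;
        if (mode == "base")  return 1024;
        if (mode == "large") return 1280;
        return -1;
    }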
@@ -433,6 +433,7 @@ struct common_params {
    std::vector<std::string> image; // path to image file(s)
    int image_min_tokens = -1;
    int image_max_tokens = -1;
    std::string dsocr_mode = "auto"; // DeepSeek-OCR resolution mode: auto, tiny, small, base, large, gundam, gundam-master

    // finetune
    struct lr_opt lr;
@@ -6013,12 +6013,14 @@ class DeepseekOCRVisionModel(MmprojModel):

    def tensor_force_quant(self, name, new_name, bid, n_dims):
        # TODO: increase numerical stability. maybe delete later.
        return gguf.GGMLQuantizationType.F32
        # related to https://github.com/ggml-org/llama.cpp/issues/13025
        if "input_projection" in name:
            return gguf.GGMLQuantizationType.F16
        if ".embeddings." in name:
            return gguf.GGMLQuantizationType.F32
        return super().tensor_force_quant(name, new_name, bid, n_dims)
        # if "input_projection" in name:
        #     return gguf.GGMLQuantizationType.F16
        # if ".embeddings." in name:
        #     return gguf.GGMLQuantizationType.F32
        # return super().tensor_force_quant(name, new_name, bid, n_dims)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Only process vision-related tensors, skip language model tensors
@@ -214,5 +214,7 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                 src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                                 sf0, sf1, sf2, sf3, pixel_offset, stream);
    } else {
        GGML_ABORT("fatal error");
    }
}
@@ -5204,6 +5204,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
    GGML_ASSERT(q->ne[3] == v->ne[3]);

    if (mask) {
        GGML_ASSERT(mask->type == GGML_TYPE_F16);
        GGML_ASSERT(ggml_is_contiguous(mask));
        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
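This assert is what the SAM attention rewrite in clip.cpp below has to satisfy when it routes its relative-position mask through ggml_flash_attn_ext: the mask's row count must be rounded up to a multiple of GGML_KQ_MASK_PAD. A minimal sketch of that rounding, equivalent to GGML_PAD(n_queries, GGML_KQ_MASK_PAD):

    // Round n_queries up to the next multiple of pad (= GGML_KQ_MASK_PAD),
    // the minimum number of mask rows the assert above accepts.
    static int64_t kq_mask_rows(int64_t n_queries, int64_t pad) {
        return ((n_queries + pad - 1) / pad) * pad;
    }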
@@ -5,6 +5,7 @@
#include <climits>
#include <cstdarg>
#include <cinttypes>
#include <cstring>
#include <string>
#include <map>
#include <sstream>
@@ -442,6 +443,33 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
// debugging
//

static std::string to_ne_string(const ggml_tensor * t) {
    std::string str;
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        str += std::to_string(t->ne[i]);
        if (i + 1 < GGML_MAX_DIMS) {
            str += ", ";
        }
    }
    return str;
}

static void print_tensor_info(ggml_tensor * t) {
    const struct ggml_tensor * src0 = t->src[0];
    const struct ggml_tensor * src1 = t->src[1];

    char src1_str[128] = {0};
    if (src1) {
        snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, to_ne_string(src1).c_str());
    }

    printf("%s: %s = %s(%s{%s}, %s)\n",
        t->name, ggml_type_name(t->type), ggml_op_desc(t),
        src0->name, to_ne_string(src0).c_str(),
        src1 ? src1_str : "");
}

static void print_tensor_shape(ggml_tensor * t) {
    printf("%s.shape = [", t->name);
    for (int i = 0; i < ggml_n_dims(t); ++i) {
@@ -453,12 +481,50 @@ static void print_tensor_shape(ggml_tensor * t) {
    printf("]\n");
}

static void print_tensor_sum(ggml_tensor * t, uint8_t * data, int64_t n) {
    (void) n; // unused parameter
    ggml_type type = t->type;
    int64_t * ne = t->ne;
    size_t * nb = t->nb;
    double sum = 0.0;
    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
                    float v;
                    if (type == GGML_TYPE_F16) {
                        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
                    } else if (type == GGML_TYPE_F32) {
                        v = *(float *) &data[i];
                    } else if (type == GGML_TYPE_I32) {
                        v = (float) *(int32_t *) &data[i];
                    } else if (type == GGML_TYPE_I16) {
                        v = (float) *(int16_t *) &data[i];
                    } else if (type == GGML_TYPE_I8) {
                        v = (float) *(int8_t *) &data[i];
                    } else {
                        GGML_ABORT("fatal error");
                    }
                    sum += v;
                }
            }
        }
    }
    printf("%s.sum = %.6f\n", t->name, sum);
}

static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
    ggml_type type = t->type;
    int64_t * ne = t->ne;
    size_t * nb = t->nb;
    printf("%s.data: [\n", t->name);
    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
        printf("%s.data: [\n", t->name);
        if (i3 == n && ne[3] > 2*n) {
            printf(" ..., \n");
            i3 = ne[3] - n;
        }
        printf(" [\n");
        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
            if (i2 == n && ne[2] > 2*n) {
                printf(" ..., \n");
@@ -500,6 +566,120 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
        }
        printf(" ]\n");
    }
    printf(" ]\n");
}

static void save_tensor_to_file(const struct ggml_tensor * tensor, const uint8_t * data_ptr) {
    char filename[512];
    snprintf(filename, sizeof(filename), "%s_cpp.txt", tensor->name);

    FILE * f = fopen(filename, "w");
    if (!f) {
        fprintf(stderr, "Failed to open %s\n", filename);
        return;
    }

    // Check tensor size and warn if too large
    int64_t total_elements = ggml_nelements(tensor);
    fprintf(stderr, "Saving tensor %s (%lld elements) to %s\n",
        tensor->name, (long long) total_elements, filename);

    if (total_elements > 10000000) { // 10M elements
        fprintf(stderr, "Warning: tensor is very large (%lld elements), this may take time\n",
            (long long) total_elements);
    }

    const uint8_t * data = (data_ptr) ? data_ptr : (uint8_t *) tensor->data;
    ggml_type type = tensor->type;
    const int64_t * ne = tensor->ne;
    const size_t * nb = tensor->nb;

    // Use a buffer to reduce I/O calls
    const size_t BUF_SIZE = 8192;
    char * buf = (char *) malloc(BUF_SIZE);
    if (!buf) {
        fprintf(stderr, "Failed to allocate buffer\n");
        fclose(f);
        return;
    }
    size_t buf_pos = 0;

    // Helper lambda to flush buffer
    auto flush_buf = [&]() {
        if (buf_pos > 0) {
            fwrite(buf, 1, buf_pos, f);
            buf_pos = 0;
        }
    };

    // Helper to append to buffer
    auto append = [&](const char * str, size_t len) {
        if (buf_pos + len >= BUF_SIZE) {
            flush_buf();
        }
        if (len >= BUF_SIZE) {
            // String too large for buffer, write directly
            fwrite(str, 1, len, f);
        } else {
            memcpy(buf + buf_pos, str, len);
            buf_pos += len;
        }
    };

    auto append_str = [&](const char * str) {
        append(str, strlen(str));
    };

    char num_buf[32];

    // Write header once for all batches
    append_str(tensor->name);
    append_str(".data: [\n");

    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
        append_str(" [\n"); // Start of batch
        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
            append_str(" [\n");
            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                append_str(" [");
                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
                    float v;
                    if (type == GGML_TYPE_F16) {
                        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
                    } else if (type == GGML_TYPE_F32) {
                        v = *(float *) &data[i];
                    } else if (type == GGML_TYPE_I32) {
                        v = (float) *(int32_t *) &data[i];
                    } else if (type == GGML_TYPE_I16) {
                        v = (float) *(int16_t *) &data[i];
                    } else if (type == GGML_TYPE_I8) {
                        v = (float) *(int8_t *) &data[i];
                    } else {
                        GGML_ABORT("fatal error");
                    }
                    int len = snprintf(num_buf, sizeof(num_buf), "%8.4f", v);
                    append(num_buf, len);
                    if (i0 < ne[0] - 1) append_str(", ");
                }
                append_str("],\n");
            }
            append_str(" ],\n");
        }
        append_str(" ]"); // End of batch
        if (i3 < ne[3] - 1) {
            append_str(",\n"); // Comma between batches
        } else {
            append_str("\n");
        }
    }

    append_str("]\n"); // Close the top-level array

    flush_buf();
    free(buf);
    fclose(f);
    fprintf(stderr, "Tensor saved successfully\n");
}

//
@@ -193,8 +193,6 @@ struct clip_hparams {
    int32_t attn_window_size = 0;
    int32_t n_wa_pattern = 0;

    bool crop_mode = false;

    // audio
    int32_t n_mel_bins = 0; // whisper preprocessor
    int32_t proj_stack_factor = 0; // ultravox
@@ -208,6 +206,9 @@ struct clip_hparams {
    int32_t custom_image_min_tokens = -1;
    int32_t custom_image_max_tokens = -1;

    // DeepSeek-OCR resolution mode
    enum clip_dsocr_mode dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO;

    void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
        const int cur_merge = n_merge == 0 ? 1 : n_merge;
        const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
@@ -512,6 +513,7 @@ struct clip_ctx {
        if (ctx_params.image_max_tokens > 0) {
            model.hparams.custom_image_max_tokens = ctx_params.image_max_tokens;
        }
        model.hparams.dsocr_mode = ctx_params.dsocr_mode;

        backend_ptrs.push_back(backend_cpu);
        backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
@@ -680,16 +682,19 @@ struct clip_graph {
        const auto tgt_size = inpL->ne[1];
        const auto str_size = model.pos_embed->ne[1];
        if (str_size != tgt_size) {
            ggml_tensor * old_pos_embed = nullptr;
            old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
            // TODO: ggml_interpolate doesn't support bicubic mode for CUDA backend
            ggml_tensor * new_pos_embed = ggml_interpolate(
                ctx0,
                model.pos_embed,
                old_pos_embed,
                tgt_size,
                tgt_size,
                enc_n_embd,
                1,
                ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
            );
            new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 2,1,0,3));
            new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3));
            cur = ggml_add(ctx0, inpL, new_pos_embed);
        } else {
            cur = ggml_add(ctx0, inpL, model.pos_embed);
@@ -698,10 +703,10 @@ struct clip_graph {
        // loop over layers
        for (int il = 0; il < _depth; il++) {
            auto & layer = model.sam_layers[il];
            ggml_tensor * shortcut = cur;

            // layernorm1
            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
            cb(cur, "enc_layer_inp_normed", il);

            const int64_t w0 = cur->ne[1];
            const int64_t h0 = cur->ne[2];
@@ -710,7 +715,7 @@ struct clip_graph {
                // local attention layer - apply window partition
                // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
                //cur = ggml_win_part(ctx0, cur, 14);
                cur = window_partition(ctx0, cur, 14);
                cur = window_partition(ctx0, cur, 14); // TODO: make this configurable
            }

            const int64_t W = cur->ne[1];
@@ -718,110 +723,93 @@ struct clip_graph {

            // self-attention
            {
                const int B = cur->ne[3];

                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
                cur = ggml_add(ctx0, cur, layer.qkv_b);
                const int B = cur->ne[3];
                cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W*H, B);

                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W * H, B);
                cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2));
                ggml_tensor * Q;
                ggml_tensor * K;
                ggml_tensor * V;

                ggml_tensor * Qcur =
                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 0);
                Qcur = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, enc_n_heads, W * H, B);
                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
                Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads);
                Q = ggml_view_3d   (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
                Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), enc_d_heads, enc_n_heads, W*H, B);
                Q = ggml_cont      (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]

                ggml_tensor * Kcur =
                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], cur->nb[3]);
                Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B);
                Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
                Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads);
                K = ggml_view_3d   (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
                K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), enc_d_heads, enc_n_heads, W*H, B);
                K = ggml_cont      (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]

                ggml_tensor * Vcur =
                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 2 * cur->nb[3]);
                Vcur = ggml_reshape_4d(ctx0, Vcur, enc_d_heads, enc_n_heads, W * H, B);
                Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); // transposed
                Vcur = ggml_reshape_3d(ctx0, Vcur, W * H, enc_d_heads, B * enc_n_heads);
                V = ggml_view_3d   (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
                V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), enc_d_heads, enc_n_heads, W*H, B);
                V = ggml_cont      (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]

                cb(Qcur, "Qcur", il);
                cb(Kcur, "Kcur", il);
                cb(Vcur, "Vcur", il);
                ggml_tensor * mask;
                ggml_tensor * rw;
                ggml_tensor * rh;
                ggml_tensor * qr;

                rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
                rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
                qr = ggml_reshape_4d(ctx0, Q, enc_d_heads, W, H, B*enc_n_heads);

                const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;

                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur);
                rw = ggml_mul_mat   (ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*enc_n_heads, W, H, W]
                rw = ggml_cont      (ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*enc_n_heads, H, W, W]
                rw = ggml_reshape_4d(ctx0, rw, W, 1, W*H, enc_n_heads*B);
                rw = ggml_repeat_4d (ctx0, rw, W, H, W*H, enc_n_heads*B);
                rh = ggml_mul_mat   (ctx0, rh, qr); // [B*enc_n_heads, H, W, H]
                rh = ggml_reshape_4d(ctx0, rh, 1, H, W*H, enc_n_heads*B);
                mask = ggml_add       (ctx0, rw, rh); // [B*enc_n_heads, H*W, H, W]
                mask = ggml_reshape_4d(ctx0, mask, W*H, W*H, enc_n_heads, B);
                mask = ggml_pad       (ctx0, mask, 0, WH_pad, 0, 0);
                mask = ggml_cast      (ctx0, mask, GGML_TYPE_F16);

                struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads));

                struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W);
                struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H);

                struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);

                struct ggml_tensor * rel_w = ggml_cont(ctx0, ggml_permute(ctx0,
                    ggml_mul_mat(ctx0,
                        rw,
                        ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
                    0, 2, 1, 3));
                struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);

                struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);

                struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);

                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcur, KQ_soft_max);

                cur = ggml_reshape_4d(
                    ctx0,
                    ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B),
                        0, 2, 1, 3)),
                    enc_n_embd, W, H, B);
                float scale = 1.0f / sqrtf((float)enc_d_heads);
                cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, enc_n_heads, enc_d_heads]

                cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), enc_n_embd, W, H, B);
                cur = ggml_mul_mat(ctx0, layer.o_w, cur);
                cur = ggml_add_inplace(ctx0, cur, layer.o_b);
            }

            if (hparams.is_global_attn(il) == false) {
                // local attention layer - reverse window partition
                cur = window_unpartition(ctx0, cur, w0, h0, 14);
                cur = window_unpartition(ctx0, cur, w0, h0, 14); // TODO: make window size configurable
            }

            // re-add the layer input, e.g., residual
            cur = ggml_add(ctx0, cur, inpL);
            cur = ggml_add(ctx0, cur, shortcut);

            ggml_tensor * inpFF = cur;

            cb(inpFF, "ffn_inp", il);

            // layernorm2
            cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
            cb(cur, "ffn_inp_normed", il);

            // ffn
            cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w,
                layer.ff_down_b, hparams.ffn_op, il);

            cb(cur, "ffn_out", il);

            // residual 2
            cur = ggml_add(ctx0, cur, inpFF);
            cb(cur, "layer_out", il);
            cb(cur, "sam_layer_out", il);
        }

        cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));

        cur = ggml_conv_2d_sk_p0(ctx0, model.neck_0_w, cur);
        const int out_chans = model.neck_0_w->ne[3];

        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_1_w, model.neck_1_b, hparams.eps);
        cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1);
        cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_1_w, model.neck_1_b, hparams.eps);
        cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
        cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_3_w, model.neck_3_b, hparams.eps);

        cur = ggml_conv_2d_s1_ph(ctx0, model.neck_2_w, cur);

        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_3_w, model.neck_3_b, hparams.eps);

        cur = ggml_conv_2d(ctx0, model.net_2, cur, 2,2,1,1, 1,1);
        cur = ggml_conv_2d(ctx0, model.net_3, cur, 2,2,1,1, 1,1);
        cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1);
        cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1);
        cb(cur, "sam_output", -1);

        ggml_build_forward_expand(gf, cur);
        return cur;
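For context on the rewrite above: SAM-style decomposed relative-position attention adds two per-axis learned bias terms to the scaled dot-product logits before softmax, roughly (reading off the code, not quoting the PR):

    attn[(h,w),(h',w')] = q[(h,w)] . k[(h',w')] / sqrt(d) + q[(h,w)] . Rw[w,w'] + q[(h,w)] . Rh[h,h']

The old path applied this bias with add_rel_pos_inplace after an explicit KQ matmul and softmax; the new path folds rw + rh into a single F16 mask, pads its rows by WH_pad so it meets the ggml_flash_attn_ext assert shown earlier, and lets the flash-attention kernel handle the scale, bias, and softmax in one fused op.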
@@ -850,34 +838,32 @@ struct clip_graph {
    ggml_cgraph * build_deepseek_ocr() {
        //patch embedding
        ggml_tensor * inp_raw = build_inp_raw();

        ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny));

        ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);

        // FIXME remove n_patches is hardcoded

        // torch: global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
        global_features_1 = ggml_cont(ctx0, ggml_permute(ctx0, global_features_1, 2, 1, 0, 3));
        global_features_1 = ggml_cont(ctx0, ggml_permute(ctx0, global_features_1, 1, 2, 0, 3));
        int clip_n_patches = global_features_1->ne[1] * global_features_1->ne[2];

        // flatten 2nd and 3rd dims
        global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches);

        // remove CLS token
        global_features_2 = ggml_view_2d(ctx0, global_features_2,
            n_embd, clip_n_patches,
            ggml_row_size(global_features_2->type, n_embd), 0);

        ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1);
        global_features_2 = ggml_view_2d(ctx0, global_features_2, n_embd, clip_n_patches,
            global_features_2->nb[1], global_features_2->nb[1]);

        ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 0);
        global_features = ggml_reshape_2d(ctx0, global_features, 2*n_embd, clip_n_patches);
        global_features = ggml_cont(ctx0, global_features);
        global_features = ggml_mul_mat(ctx0, model.fc_w, global_features);
        global_features = ggml_add(ctx0, global_features, model.fc_b);

        global_features = build_global_local_features(ctx0, global_features);
        global_features = ggml_cont(ctx0, ggml_permute(ctx0, global_features, 1, 0, 2, 3));

        cb(global_features, "dsocr_output", -1);

        ggml_build_forward_expand(gf, global_features);
        return gf;
    }
@@ -891,30 +877,23 @@ struct clip_graph {
        GGML_ASSERT(model.image_newline != nullptr);
        GGML_ASSERT(model.view_seperator != nullptr);

        // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
        const auto h = static_cast<int>(std::sqrt(static_cast<float>(global_features->ne[1])));
        const auto w = h;
        const auto n_dim = global_features->ne[0];
        ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, n_dim, h, w, 1); // (n_dim, w, h)
        t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
        ggml_tensor * nl = ggml_cont(ctx0, ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
        nl = ggml_repeat_4d(ctx0, nl, h, 1, n_dim, 1); // n_pos rows

        ggml_tensor * cur;
        ggml_tensor * imgnl;
        ggml_tensor * vs;

        // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
        t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)
        cur = ggml_reshape_3d(ctx0, global_features, n_dim, w, h);
        imgnl = ggml_repeat_4d(ctx0, model.image_newline, n_dim, 1, h, 1);
        cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w+1)*h);
        cb(cur, "insert_imgnl", -1);
        vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
        cur = ggml_concat(ctx0, cur, vs, 1); // (n_dim, h*(w+1) + 1)
        cb(cur, "insert_vs", -1);

        t = ggml_reshape_2d(ctx0, t, n_dim, h*(h + 1)); // (n_dim, h*(w+1))

        // 5) append view_separator as an extra "token":
        //    view_separator: [n_dim] -> [n_dim, 1]
        ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)

        // concat along token dimension (dim=1):
        t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)

        return t;
        return cur;
    }
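To make the resulting token count concrete, a worked trace assuming a 16x16 post-SAM feature grid (the "shape [1024, 16, 16]" case noted in a later hunk): each of the h = 16 rows gains one image_newline token, giving h*(w+1) = 272, and the trailing view separator makes it h*(w+1) + 1 = 273. This is the same formula clip_n_output_tokens uses for this projector further down.

    // Tokens after build_global_local_features, per the shapes above:
    // h rows of (w + 1) tokens (w features + 1 newline), plus 1 separator.
    static int dsocr_output_tokens(int h, int w) {
        return h * (w + 1) + 1; // e.g. h = w = 16 -> 273
    }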
@@ -1569,8 +1548,8 @@ struct clip_graph {
        ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));

        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 2, 1, 0, 3));
        inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]);
        inp = ggml_reshape_2d(ctx0, inp, inp->ne[0]*inp->ne[1], inp->ne[2]);
        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));

        ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings));
@@ -1601,7 +1580,7 @@ struct clip_graph {

        // add CLS token
        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
        inp = ggml_concat(ctx0, model.class_embedding, inp, 1);

        // TODO: check norm type for dp-ocr-clip
        norm_type norm_t = NORM_TYPE_NORMAL;
@@ -1610,9 +1589,8 @@ struct clip_graph {
        ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
        ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);

        ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd,
            nullptr); // shape [1024, 16, 16]
        ggml_tensor * cur = build_vit(inp, n_pos, norm_t, ffn_op_type::FFN_GELU_QUICK,
            learned_pos_embd, nullptr); // shape [1024, 16, 16]

        ggml_build_forward_expand(gf, cur);
@@ -2576,19 +2554,27 @@ private:

        if (max_rel_dist != L) {
            // Linear interpolation
            const auto scale = L / static_cast<float>(max_rel_dist);
            ggml_tensor * indices = ggml_arange(ctx, 0.0f, static_cast<float>(max_rel_dist), 1.0f);
            indices = ggml_scale_inplace(ctx, indices, scale);
            ggml_tensor * indices_floor = ggml_cast(ctx, ggml_floor(ctx, indices), GGML_TYPE_I32);
            ggml_tensor * indices_ceil  = ggml_cast(ctx, ggml_ceil(ctx, indices), GGML_TYPE_I32);
            ggml_tensor * weights = ggml_sub(ctx, indices, indices_floor);
            ggml_tensor * ws1 = ggml_scale_bias(ctx, weights, -1.0f, 1.0f);
            rel_pos_resized = ggml_cont(ctx, ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)); // [C, L] for ggml_get_rows
            ggml_tensor * rs1 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_floor), 1, 0, 2, 3)); // lower rows
            rs1 = ggml_mul(ctx, rs1, ws1); // lower rows
            ggml_tensor * rs2 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_ceil), 1, 0, 2, 3)); // upper rows
            rs2 = ggml_mul(ctx, rs2, weights); // upper rows
            rel_pos_resized = ggml_add(ctx, rs1, rs2);
            int64_t ne0 = rel_pos_resized->ne[0];
            int64_t ne1 = rel_pos_resized->ne[1];
            int64_t ne2 = rel_pos_resized->ne[2];
            int64_t ne3 = rel_pos_resized->ne[3];

            rel_pos_resized = ggml_reshape_3d(
                ctx,
                ggml_cont(ctx, ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)),
                ne1, 1, ne0*ne2*ne3
            );
            rel_pos_resized = ggml_reshape_4d(
                ctx,
                ggml_interpolate(
                    ctx,
                    rel_pos_resized,
                    max_rel_dist, 1, ne0*ne2*ne3, 1,
                    ggml_scale_mode::GGML_SCALE_MODE_BILINEAR
                ),
                max_rel_dist, ne0, ne2, ne3
            );
            rel_pos_resized = ggml_cont(ctx, ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3));
        }

        // -------------------------------------------------
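Both the removed floor/ceil code and the new GGML_SCALE_MODE_BILINEAR call perform 1-D linear resampling of the relative-position table along its length-L axis (up to each path's sampling-offset convention). A scalar sketch of that operation:

    // 1-D linear interpolation at fractional position pos in [0, L-1]:
    // blend the floor and ceil neighbours by the fractional weight.
    static float lerp_at(const float * row, int L, float pos) {
        int   lo = (int) pos;                   // floor index
        int   hi = lo + 1 < L ? lo + 1 : L - 1; // ceil index, clamped
        float w  = pos - (float) lo;
        return (1.0f - w) * row[lo] + w * row[hi];
    }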
@@ -2627,7 +2613,7 @@ private:
        rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size]
        rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size]
        // Clamp to [0, L-1] range for valid indexing
        rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos->ne[1] - 1));
        rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos_resized->ne[1] - 1));

        // -------------------------------------------------
        // clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
@@ -2641,7 +2627,7 @@ private:
        // flatten to 1D for ggml_get_rows
        int qk = q_size * k_size;
        ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
        ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
        ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos_resized, idx_flat); // [qk, C]

        // -------------------------------------------------
        // Gather from rel_pos → [qk, C]
@@ -2671,7 +2657,7 @@ private:
        }
        x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b);
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
        x = ggml_reshape_4d(ctx, x, c, window ,window, npw * nph * b);
        x = ggml_reshape_4d(ctx, x, c, window, window, npw * nph * b);
        return x;
    }
@@ -3419,7 +3405,6 @@ struct clip_model_loader {
                hparams.patch_size = 16;
                hparams.image_size = 1024;
                hparams.warmup_image_size = 1024;
                hparams.crop_mode = false;
            } break;
        default:
            break;
@@ -5070,9 +5055,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                }
            } break;
        case PROJECTOR_TYPE_DEEPSEEKOCR:
            if (!params.crop_mode) {
                /* Native Resolution (Tiny/Small/Base/Large) */
            {
                const int native_resolutions[] = {
                    512 /* tiny */, 640 /* small */, 1024 /* base */, 1280 /* large */
                };
@@ -5080,29 +5063,49 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                const int orig_w = original_size.width;
                const int orig_h = original_size.height;
                const int orig_area = orig_h * orig_w;

                // mode selection logic (find most suitable resolution)
                std::array<uint8_t, 3u> color;

                for (int i = 0; i < 3; i++) {
                    color[i] = (int)(255 * params.image_mean[i]);
                }

                int mode_i = 0;
                int min_diff = orig_area;

                for (int i = 0; i < 4; i++) {
                    int r = native_resolutions[i];
                    if (std::abs(orig_area - r*r) < min_diff) {
                        mode_i = i;
                        min_diff = std::abs(orig_area - r*r);

                if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_TINY) {
                    mode_i = 0;
                } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_SMALL) {
                    mode_i = 1;
                } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_BASE) {
                    mode_i = 2;
                } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_LARGE) {
                    mode_i = 3;
                } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM) {
                    mode_i = 4;
                } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM_MASTER) {
                    mode_i = 5;
                } else {
                    if (params.dsocr_mode != clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO) {
                        LOG_WRN("%s: unknown dsocr_mode, using auto mode\n", __func__);
                    }
                    int min_diff = orig_area;
                    for (int i = 0; i < 4; i++) {
                        int r = native_resolutions[i];
                        if (std::abs(orig_area - r*r) < min_diff) {
                            mode_i = i;
                            min_diff = std::abs(orig_area - r*r);
                        }
                    }
                }

                const int image_size = native_resolutions[mode_i];

                if (mode_i < 2) {
                    // TINY/SMALL MODE: Direct resize (no slicing)
                    /* Native Resolution (Tiny/Small) */
                    const int image_size = native_resolutions[mode_i];

                    // Just resize the image to image_size × image_size
                    clip_image_u8_ptr resized_img(clip_image_u8_init());
                    img_tool::resize(*img, *resized_img,
                        clip_image_size{image_size, image_size},
                        img_tool::RESIZE_ALGO_BICUBIC); // Match PIL default
                        img_tool::RESIZE_ALGO_BICUBIC, true, color); // Match PIL default

                    clip_image_f32_ptr res(clip_image_f32_init());
                    normalize_image_u8_to_f32(*resized_img, *res, params.image_mean, params.image_std);
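A standalone sketch of the auto selection above: pick the native resolution whose square area is closest to the input area. For example, an 800x1000 page (area 800000) selects base, since |800000 - 1024*1024| = 248576 is the smallest of the four differences.

    // Same selection rule as the auto branch above, in isolation.
    static int dsocr_auto_resolution(int orig_w, int orig_h) {
        const int native_resolutions[4] = { 512, 640, 1024, 1280 };
        const int orig_area = orig_w * orig_h;
        int best     = native_resolutions[0];
        int min_diff = orig_area;
        for (int r : native_resolutions) {
            if (std::abs(orig_area - r * r) < min_diff) {
                min_diff = std::abs(orig_area - r * r);
                best     = r;
            }
        }
        return best;
    }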
@@ -5111,10 +5114,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                    res_imgs->grid_x = 1;
                    res_imgs->grid_y = 1;
                }
                else {
                    // BASE/LARGE MODE: Resize with aspect ratio + padding
                else if (mode_i < 4) {
                    /* Native Resolution (Base/Large) */
                    const int image_size = native_resolutions[mode_i];

                    // Resize maintaining aspect ratio, then pad to square
                    float scale = std::min(
                        static_cast<float>(image_size) / orig_w,
                        static_cast<float>(image_size) / orig_h
@@ -5124,14 +5128,14 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

                    clip_image_u8_ptr scaled_img(clip_image_u8_init());
                    img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h},
                        img_tool::RESIZE_ALGO_BICUBIC);
                        img_tool::RESIZE_ALGO_BICUBIC, true, color);

                    // Use mean color for padding
                    unsigned char pad_r = static_cast<unsigned char>(params.image_mean[0] * 255.0f);
                    unsigned char pad_g = static_cast<unsigned char>(params.image_mean[1] * 255.0f);
                    unsigned char pad_b = static_cast<unsigned char>(params.image_mean[2] * 255.0f);

                    // Step 2: Pad to image_size × image_size (center padding)
                    // Pad to image_size × image_size (center padding)
                    clip_image_u8_ptr padded_img(clip_image_u8_init());
                    padded_img->nx = image_size;
                    padded_img->ny = image_size;
@@ -5159,7 +5163,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                        }
                    }

                    // Step 3: Normalize and output
                    // Normalize and output
                    clip_image_f32_ptr res(clip_image_f32_init());
                    normalize_image_u8_to_f32(*padded_img, *res, params.image_mean, params.image_std);
                    res_imgs->entries.push_back(std::move(res));
@@ -5167,68 +5171,69 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                    res_imgs->grid_x = 1;
                    res_imgs->grid_y = 1;
                }
            }
            else {
                /* Dynamic Resolution (Gundam/Gundam-M) */

                // configurable, or read from params
                const int min_num = 2;
                const int max_num = 9;
                const int image_size = params.image_size; // typically 640
                // const bool use_thumbnail = true; // mimic python's use_thumbnail

                // original image size
                const int orig_w = original_size.width;
                const int orig_h = original_size.height;

                // 1) build candidate grids (cols, rows)
                auto target_ratios = ds_build_target_ratios(min_num, max_num);

                // 2) pick the grid that best matches the original aspect ratio
                const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
                auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
                const int grid_cols = best.first;  // how many tiles horizontally
                const int grid_rows = best.second; // how many tiles vertically

                // 3) compute the target (forced) size — python did:
                //    target_width  = image_size * cols
                //    target_height = image_size * rows
                const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows };

                // 4) prepare slice instructions, same style as the idefics3 branch
                llava_uhd::slice_instructions instructions;
                instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global
                instructions.refined_size = refined_size;
                instructions.grid_size = clip_image_size{ grid_cols, grid_rows };

                // in deepseek python they always produce *full* 640x640 blocks,
                // so we can do a simple double loop over rows/cols:
                for (int r = 0; r < grid_rows; ++r) {
                    for (int c = 0; c < grid_cols; ++c) {
                        const int x = c * image_size;
                        const int y = r * image_size;

                        instructions.slices.push_back(llava_uhd::slice_coordinates{
                            /* x */ x,
                            /* y */ y,
                            /* size */ clip_image_size{ image_size, image_size }
                        });
            else {
                GGML_ABORT("DeepSeek-OCR: Gundam/Gundam-Master haven't been tested yet.\n");
                /* Dynamic Resolution (Gundam/Gundam-Master) */

                // configurable, or read from params
                const int min_num = 2;
                const int max_num = 9;
                const int image_size = params.image_size; // typically 640
                // const bool use_thumbnail = true; // mimic python's use_thumbnail

                // original image size
                const int orig_w = original_size.width;
                const int orig_h = original_size.height;

                // 1) build candidate grids (cols, rows)
                auto target_ratios = ds_build_target_ratios(min_num, max_num);

                // 2) pick the grid that best matches the original aspect ratio
                const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
                auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
                const int grid_cols = best.first;  // how many tiles horizontally
                const int grid_rows = best.second; // how many tiles vertically

                // 3) compute the target (forced) size — python did:
                //    target_width  = image_size * cols
                //    target_height = image_size * rows
                const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows };

                // 4) prepare slice instructions, same style as the idefics3 branch
                llava_uhd::slice_instructions instructions;
                instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global
                instructions.refined_size = refined_size;
                instructions.grid_size = clip_image_size{ grid_cols, grid_rows };

                // in deepseek python they always produce *full* 640x640 blocks,
                // so we can do a simple double loop over rows/cols:
                for (int r = 0; r < grid_rows; ++r) {
                    for (int c = 0; c < grid_cols; ++c) {
                        const int x = c * image_size;
                        const int y = r * image_size;

                        instructions.slices.push_back(llava_uhd::slice_coordinates{
                            /* x */ x,
                            /* y */ y,
                            /* size */ clip_image_size{ image_size, image_size }
                        });
                    }
                }

                // 5) run the actual slicing (this should: resize to refined_size, then crop every slice)
                auto imgs = llava_uhd::slice_image(img, instructions);

                // 7) cast & normalize like the idefics3 branch
                for (size_t i = 0; i < imgs.size(); ++i) {
                    clip_image_f32_ptr res(clip_image_f32_init());
                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
                    res_imgs->entries.push_back(std::move(res));
                }

                // keep the grid info — the model may need to know how to reassemble / attend
                res_imgs->grid_x = grid_cols;
                res_imgs->grid_y = grid_rows;
            }

                // 5) run the actual slicing (this should: resize to refined_size, then crop every slice)
                auto imgs = llava_uhd::slice_image(img, instructions);

                // 7) cast & normalize like the idefics3 branch
                for (size_t i = 0; i < imgs.size(); ++i) {
                    clip_image_f32_ptr res(clip_image_f32_init());
                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
                    res_imgs->entries.push_back(std::move(res));
                }

                // keep the grid info — the model may need to know how to reassemble / attend
                res_imgs->grid_x = grid_cols;
                res_imgs->grid_y = grid_rows;
            }
            break;
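The comments above lean on two helpers; a hypothetical sketch of step 1), assuming (as the comments say) that every (cols, rows) grid whose tile count falls in [min_num, max_num] is a candidate — the PR's actual ds_build_target_ratios may differ in details:

    // Enumerate candidate tile grids; ds_find_closest_aspect_ratio then
    // scores these against the input aspect ratio (step 2).
    static std::vector<std::pair<int, int>> build_target_ratios(int min_num, int max_num) {
        std::vector<std::pair<int, int>> ratios;
        for (int n = min_num; n <= max_num; ++n) {
            for (int cols = 1; cols <= n; ++cols) {
                if (n % cols == 0) {
                    ratios.push_back({ cols, n / cols }); // cols * rows == n tiles
                }
            }
        }
        return ratios;
    }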
@@ -5415,12 +5420,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
            } break;
        case PROJECTOR_TYPE_DEEPSEEKOCR:
            {
                int x_patch = img->nx / (params.patch_size);

                n_patches += x_patch + 1;
                n_patches = 1280;

                // SAM encoder applies two stride-2 convolutions (net_2 and net_3)
                // which reduces spatial dimensions by 4x in each direction (16x total)
                // E.g., 64x64 -> 16x16 patches
                n_patches /= 16;

                // build_global_local_features adds image newlines and view separator
                // Formula: h*(w+1) + 1 where h = w = sqrt(n_patches)
                int h = static_cast<int>(std::sqrt(static_cast<float>(n_patches)));
                n_patches = h * (h + 1) + 1;
            } break;
        default:
            GGML_ABORT("unsupported projector type");
@@ -5803,8 +5811,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
        for (ggml_tensor * t : ctx->debug_print_tensors) {
            std::vector<uint8_t> data(ggml_nbytes(t));
            ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
            print_tensor_info(t);
            print_tensor_shape(t);
            print_tensor_data(t, data.data(), 3);
            print_tensor_sum(t, data.data(), 3);
            std::string tname_s = std::string(t->name);

            bool is_stored = false;
            std::vector<std::string> patterns = {
                /* Add tensor names here to dump (e.g. "sam_output") */
            };

            for (auto & p : patterns) {
                if (tname_s == p) {
                    save_tensor_to_file(t, data.data());
                    is_stored = true;
                    break;
                }
            }

            if (!is_stored) {
                print_tensor_data(t, data.data(), 3);
            }
        }
    }
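As the placeholder comment indicates, tensors are dumped to file by exact name match against the graph names assigned via cb(); for example, using two names set earlier in this diff:

    std::vector<std::string> patterns = {
        "sam_output",   // SAM encoder output, cb(cur, "sam_output", -1)
        "dsocr_output", // projector output, cb(global_features, "dsocr_output", -1)
    };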
@@ -29,11 +29,22 @@ enum clip_flash_attn_type {
    CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
};

enum clip_dsocr_mode {
    CLIP_DSOCR_MODE_AUTO,
    CLIP_DSOCR_MODE_TINY,
    CLIP_DSOCR_MODE_SMALL,
    CLIP_DSOCR_MODE_BASE,
    CLIP_DSOCR_MODE_LARGE,
    CLIP_DSOCR_MODE_GUNDAM,
    CLIP_DSOCR_MODE_GUNDAM_MASTER,
};

struct clip_context_params {
    bool use_gpu;
    enum clip_flash_attn_type flash_attn_type;
    int image_min_tokens;
    int image_max_tokens;
    enum clip_dsocr_mode dsocr_mode;
};

struct clip_init_result {
@@ -138,6 +138,7 @@ struct mtmd_cli_context {
        mparams.flash_attn_type = params.flash_attn_type;
        mparams.image_min_tokens = params.image_min_tokens;
        mparams.image_max_tokens = params.image_max_tokens;
        mparams.dsocr_mode = params.dsocr_mode.c_str();
        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
        if (!ctx_vision.get()) {
            LOG_ERR("Failed to load vision model from %s\n", clip_path);
@@ -110,6 +110,7 @@ mtmd_context_params mtmd_context_params_default() {
        /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
        /* image_min_tokens */ -1,
        /* image_max_tokens */ -1,
        /* dsocr_mode */ "auto",
    };
    return params;
}
@@ -172,11 +173,32 @@ struct mtmd_context {
        throw std::runtime_error("media_marker must not be empty");
    }

    enum clip_dsocr_mode dsocr_mode;

    if (std::string(ctx_params.dsocr_mode) == "auto") {
        dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO;
    } else if (std::string(ctx_params.dsocr_mode) == "tiny") {
        dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_TINY;
    } else if (std::string(ctx_params.dsocr_mode) == "small") {
        dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_SMALL;
    } else if (std::string(ctx_params.dsocr_mode) == "base") {
        dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_BASE;
    } else if (std::string(ctx_params.dsocr_mode) == "large") {
        dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_LARGE;
    } else if (std::string(ctx_params.dsocr_mode) == "gundam") {
        dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM;
    } else if (std::string(ctx_params.dsocr_mode) == "gundam-master") {
        dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM_MASTER;
    } else {
        throw std::invalid_argument("invalid value");
    }

    clip_context_params ctx_clip_params {
        /* use_gpu */ ctx_params.use_gpu,
        /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
        /* image_min_tokens */ ctx_params.image_min_tokens,
        /* image_max_tokens */ ctx_params.image_max_tokens,
        /* dsocr_mode */ dsocr_mode,
    };

    auto res = clip_init(mmproj_fname, ctx_clip_params);
@@ -86,6 +86,9 @@ struct mtmd_context_params {
    // limit number of image tokens, only for vision models with dynamic resolution
    int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
    int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)

    // DeepSeek-OCR resolution mode
    const char * dsocr_mode; // one of: auto, tiny, small, base, large, gundam, gundam-master
};

MTMD_API const char * mtmd_default_marker(void);