- dynamic resizing
- changes are concerning PR https://github.com/sfallah/llama.cpp/pull/4
This commit is contained in:
parent
7941f5d8ff
commit
206f8abc3c
|
|
@ -676,7 +676,25 @@ struct clip_graph {
|
||||||
const int enc_n_patches = enc_image_size / enc_patch_size; // 64
|
const int enc_n_patches = enc_image_size / enc_patch_size; // 64
|
||||||
|
|
||||||
ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd);
|
ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd);
|
||||||
ggml_tensor * cur = ggml_add(ctx0, inpL, model.pos_embed);
|
ggml_tensor * cur = nullptr;
|
||||||
|
|
||||||
|
const auto tgt_size = inpL->ne[1];
|
||||||
|
const auto str_size = model.pos_embed->ne[1];
|
||||||
|
if (str_size != tgt_size) {
|
||||||
|
ggml_tensor * new_pos_embed = ggml_interpolate(
|
||||||
|
ctx0,
|
||||||
|
model.pos_embed,
|
||||||
|
tgt_size,
|
||||||
|
tgt_size,
|
||||||
|
enc_n_embd,
|
||||||
|
1,
|
||||||
|
ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
|
||||||
|
);
|
||||||
|
new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 2,1,0,3));
|
||||||
|
cur = ggml_add(ctx0, inpL, new_pos_embed);
|
||||||
|
} else {
|
||||||
|
cur = ggml_add(ctx0, inpL, model.pos_embed);
|
||||||
|
}
|
||||||
|
|
||||||
// loop over layers
|
// loop over layers
|
||||||
for (int il = 0; il < _depth; il++) {
|
for (int il = 0; il < _depth; il++) {
|
||||||
|
|
@ -840,10 +858,11 @@ struct clip_graph {
|
||||||
ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);
|
ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);
|
||||||
|
|
||||||
// FIXME remove n_patches is hardcoded
|
// FIXME remove n_patches is hardcoded
|
||||||
int clip_n_patches = 256; // FIXME hardcoded for sam 1024x1024 with 16x16 patches
|
|
||||||
|
|
||||||
// torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
// torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
||||||
global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3));
|
global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3));
|
||||||
|
int clip_n_patches = global_features_1->ne[1] * global_features_1->ne[2];
|
||||||
|
|
||||||
// flatten 2nd and 3rd dims
|
// flatten 2nd and 3rd dims
|
||||||
global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches);
|
global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches);
|
||||||
|
|
||||||
|
|
@ -874,21 +893,24 @@ struct clip_graph {
|
||||||
GGML_ASSERT(model.view_seperator != nullptr);
|
GGML_ASSERT(model.view_seperator != nullptr);
|
||||||
|
|
||||||
// 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
|
// 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
|
||||||
ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 16, 16, 1); // (n_dim, w, h)
|
const auto h = static_cast<int>(std::sqrt(static_cast<float>(global_features->ne[1])));
|
||||||
|
const auto w = h;
|
||||||
|
const auto n_dim = global_features->ne[0];
|
||||||
|
ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, n_dim, h, w, 1); // (n_dim, w, h)
|
||||||
t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
|
t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
|
||||||
ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
|
ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
|
||||||
nl = ggml_repeat_4d(ctx0, nl, 16, 1, 1280, 1); // n_pos rows
|
nl = ggml_repeat_4d(ctx0, nl, h, 1, n_dim, 1); // n_pos rows
|
||||||
|
|
||||||
|
|
||||||
// 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
|
// 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
|
||||||
t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)
|
t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)
|
||||||
|
|
||||||
t = ggml_reshape_2d(ctx0, t, 1280, 16 * (16 + 1)); // (n_dim, h*(w+1))
|
t = ggml_reshape_2d(ctx0, t, n_dim, h* (h + 1)); // (n_dim, h*(w+1))
|
||||||
|
|
||||||
|
|
||||||
// 5) append view_separator as an extra "token":
|
// 5) append view_separator as an extra "token":
|
||||||
// view_separator: [n_dim] -> [n_dim, 1]
|
// view_separator: [n_dim] -> [n_dim, 1]
|
||||||
ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, 1280, 1); // (n_dim, 1)
|
ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
|
||||||
|
|
||||||
// concat along token dimension (dim=1):
|
// concat along token dimension (dim=1):
|
||||||
t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
|
t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
|
||||||
|
|
@ -1547,10 +1569,35 @@ struct clip_graph {
|
||||||
ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));
|
ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));
|
||||||
|
|
||||||
|
|
||||||
const int n_pos = 257; // +1 for [CLS]
|
|
||||||
inp = ggml_cont(ctx0,ggml_permute(ctx0, inp,2,1,0,3));
|
inp = ggml_cont(ctx0,ggml_permute(ctx0, inp,2,1,0,3));
|
||||||
inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]);
|
inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]);
|
||||||
|
|
||||||
|
ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings));
|
||||||
|
|
||||||
|
int n_pos = new_pos_embd->ne[1]; // +1 for [CLS]
|
||||||
|
const auto tgt_size = static_cast<int>(std::sqrt(inp->ne[1]));
|
||||||
|
const auto src_size = static_cast<int>(std::sqrt(n_pos - 1));
|
||||||
|
|
||||||
|
|
||||||
|
if (tgt_size != src_size) {
|
||||||
|
//ggml_tensor * old_pos_embd = ggml_new_tensor_2d(ctx0, model.position_embeddings->type, model.position_embeddings->ne[0], str_size * str_size);
|
||||||
|
ggml_tensor * old_pos_embd = ggml_view_2d(ctx0, new_pos_embd,
|
||||||
|
new_pos_embd->ne[0], src_size * src_size,
|
||||||
|
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0);
|
||||||
|
ggml_tensor * cls_tok = ggml_view_2d(ctx0, new_pos_embd,
|
||||||
|
new_pos_embd->ne[0], 1,
|
||||||
|
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size);
|
||||||
|
new_pos_embd = ggml_interpolate(ctx0,
|
||||||
|
old_pos_embd,
|
||||||
|
tgt_size,
|
||||||
|
tgt_size,
|
||||||
|
new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC);
|
||||||
|
new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1);
|
||||||
|
//new_pos_embd = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embd, 2,1,0,3));
|
||||||
|
new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1);
|
||||||
|
n_pos = tgt_size * tgt_size + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// add CLS token
|
// add CLS token
|
||||||
|
|
@ -1560,11 +1607,8 @@ struct clip_graph {
|
||||||
norm_type norm_t = NORM_TYPE_NORMAL;
|
norm_type norm_t = NORM_TYPE_NORMAL;
|
||||||
|
|
||||||
// for selecting learned pos embd, used by ViT
|
// for selecting learned pos embd, used by ViT
|
||||||
ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
|
ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
|
||||||
ggml_set_name(positions, "positions");
|
ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);
|
||||||
ggml_set_input(positions);
|
|
||||||
|
|
||||||
ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);
|
|
||||||
|
|
||||||
|
|
||||||
ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd,
|
ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd,
|
||||||
|
|
@ -2525,7 +2569,27 @@ private:
|
||||||
const int64_t C = rel_pos->ne[0]; // channels
|
const int64_t C = rel_pos->ne[0]; // channels
|
||||||
const int64_t L = rel_pos->ne[1]; // length
|
const int64_t L = rel_pos->ne[1]; // length
|
||||||
|
|
||||||
GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);
|
//GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);
|
||||||
|
|
||||||
|
const auto max_rel_dist = 2*std::max(q_size, k_size) - 1;
|
||||||
|
ggml_tensor * rel_pos_resized = rel_pos;
|
||||||
|
|
||||||
|
if (max_rel_dist != L) {
|
||||||
|
// Linear interpolation
|
||||||
|
const auto scale = L / static_cast<float>(max_rel_dist);
|
||||||
|
ggml_tensor * indices = ggml_arange(ctx, 0.0f, static_cast<float>(max_rel_dist), 1.0f);
|
||||||
|
indices = ggml_scale_inplace(ctx, indices, scale);
|
||||||
|
ggml_tensor * indices_floor= ggml_cast(ctx, ggml_floor(ctx, indices), GGML_TYPE_I32);
|
||||||
|
ggml_tensor * indices_ceil = ggml_cast(ctx, ggml_ceil(ctx, indices), GGML_TYPE_I32);
|
||||||
|
ggml_tensor * weights = ggml_sub(ctx, indices, indices_floor);
|
||||||
|
ggml_tensor * ws1 = ggml_scale_bias(ctx, weights, -1.0f, 1.0f);
|
||||||
|
rel_pos_resized = ggml_cont(ctx , ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)); // [C, L] for ggml_get_rows
|
||||||
|
ggml_tensor * rs1 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_floor), 1, 0, 2, 3)); // lower rows
|
||||||
|
rs1 = ggml_mul(ctx, rs1, ws1); // lower rows
|
||||||
|
ggml_tensor * rs2 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_ceil), 1, 0, 2, 3)); // upper rows
|
||||||
|
rs2 = ggml_mul(ctx, rs2, weights); // upper rows
|
||||||
|
rel_pos_resized = ggml_add(ctx,rs1, rs2);
|
||||||
|
}
|
||||||
|
|
||||||
// -------------------------------------------------
|
// -------------------------------------------------
|
||||||
// 1) q_idx ← arange(0..q_size-1) [q_size]
|
// 1) q_idx ← arange(0..q_size-1) [q_size]
|
||||||
|
|
@ -5007,7 +5071,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||||
if (!params.crop_mode) {
|
if (!params.crop_mode) {
|
||||||
/* Native Resolution (Tiny/Small/Base/Large) */
|
/* Native Resolution (Tiny/Small/Base/Large) */
|
||||||
|
|
||||||
const int native_resolutions[] = {
|
const int native_resolutions[] = {
|
||||||
512 /* tiny */, 640 /* small */, 1024 /* base */, 1280 /* large */
|
512 /* tiny */, 640 /* small */, 1024 /* base */, 1280 /* large */
|
||||||
};
|
};
|
||||||
// original image size
|
// original image size
|
||||||
|
|
@ -5060,7 +5124,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||||
img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h},
|
img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h},
|
||||||
img_tool::RESIZE_ALGO_BICUBIC);
|
img_tool::RESIZE_ALGO_BICUBIC);
|
||||||
|
|
||||||
// Use mean color for padding
|
// Use mean color for padding
|
||||||
unsigned char pad_r = static_cast<unsigned char>(params.image_mean[0] * 255.0f);
|
unsigned char pad_r = static_cast<unsigned char>(params.image_mean[0] * 255.0f);
|
||||||
unsigned char pad_g = static_cast<unsigned char>(params.image_mean[1] * 255.0f);
|
unsigned char pad_g = static_cast<unsigned char>(params.image_mean[1] * 255.0f);
|
||||||
unsigned char pad_b = static_cast<unsigned char>(params.image_mean[2] * 255.0f);
|
unsigned char pad_b = static_cast<unsigned char>(params.image_mean[2] * 255.0f);
|
||||||
|
|
@ -5352,6 +5416,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||||
int x_patch = img->nx / (params.patch_size);
|
int x_patch = img->nx / (params.patch_size);
|
||||||
|
|
||||||
n_patches += x_patch + 1;
|
n_patches += x_patch + 1;
|
||||||
|
n_patches = 1280;
|
||||||
|
|
||||||
|
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
|
|
@ -5690,14 +5756,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
} break;
|
} break;
|
||||||
case PROJECTOR_TYPE_DEEPSEEKOCR:
|
case PROJECTOR_TYPE_DEEPSEEKOCR:
|
||||||
{
|
{
|
||||||
//FIXME we need correct this when all model configs are set correctly
|
|
||||||
//n_patch is not correct right now
|
|
||||||
int32_t n_pos = 16 * 16 + 1; //hardcode for now
|
|
||||||
std::vector<int32_t> positions(n_pos);
|
|
||||||
for (int i = 0; i < n_pos; i++) {
|
|
||||||
positions[i] = i;
|
|
||||||
}
|
|
||||||
set_input_i32("positions", positions);
|
|
||||||
} break;
|
} break;
|
||||||
case PROJECTOR_TYPE_LLAMA4:
|
case PROJECTOR_TYPE_LLAMA4:
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue