// llama.cpp/tools/mtmd/swin.h
// Swin Transformer vision encoder — model structures and public API declarations.
#pragma once

#include "ggml.h"
#include "ggml-backend.h"
#include "clip-impl.h"

#include <string>
#include <utility>
#include <vector>
// Swin Transformer constants
// GGUF metadata keys for the Swin hyperparameters (read at model-load time).
#define KEY_SWIN_WINDOW_SIZE "swin.window_size"
#define KEY_SWIN_PATCH_SIZE "swin.patch_size"
#define KEY_SWIN_IMAGE_SIZE "swin.image_size"
#define KEY_SWIN_DEPTHS "swin.depths"
#define KEY_SWIN_NUM_HEADS "swin.num_heads"
#define KEY_SWIN_HIDDEN_DIM "swin.hidden_dim"
#define KEY_SWIN_NUM_CHANNELS "swin.num_channels"
#define KEY_SWIN_MLP_RATIO "swin.mlp_ratio"
#define KEY_SWIN_DROP_PATH_RATE "swin.drop_path_rate"
#define KEY_SWIN_NORM_EPS "swin.norm_eps"
// Tensor names for Swin Transformer
// printf-style templates: %d slots are (stage index, layer index), %s is the
// suffix ("weight"/"bias" — presumably, matching clip-impl.h usage; confirm).
#define TN_SWIN_PATCH_EMBED "swin.patch_embed.weight"
#define TN_SWIN_PATCH_NORM "swin.patch_embed.norm.%s"
#define TN_SWIN_POS_EMBED "swin.pos_embed"
#define TN_SWIN_DOWNSAMPLE_NORM "swin.stage.%d.downsample.norm.%s"
#define TN_SWIN_DOWNSAMPLE_PROJ "swin.stage.%d.downsample.reduction.weight"
#define TN_SWIN_ATTN_NORM "swin.stage.%d.layer.%d.norm1.%s"
#define TN_SWIN_ATTN_QKV "swin.stage.%d.layer.%d.attn.qkv.%s"
#define TN_SWIN_ATTN_PROJ "swin.stage.%d.layer.%d.attn.proj.%s"
#define TN_SWIN_ATTN_REL_POS "swin.stage.%d.layer.%d.attn.relative_position_bias_table"
#define TN_SWIN_FFN_NORM "swin.stage.%d.layer.%d.norm2.%s"
#define TN_SWIN_FFN_FC1 "swin.stage.%d.layer.%d.mlp.fc1.%s"
#define TN_SWIN_FFN_FC2 "swin.stage.%d.layer.%d.mlp.fc2.%s"
#define TN_SWIN_OUTPUT_NORM "swin.norm.%s"
// Forward declarations
struct swin_ctx;
// Swin Transformer hyperparameters
// Swin Transformer hyperparameters.
// Defaults mirror a Swin-T-style configuration; actual values are expected to
// be overridden from GGUF metadata (see the KEY_SWIN_* keys above).
struct swin_hparams {
    int32_t image_size   = 384; // input image side length (square), pixels
    int32_t patch_size   = 4;   // side length of one patch, pixels
    int32_t num_channels = 3;   // input channels (RGB)
    int32_t window_size  = 7;   // attention window side length, in patches
    int32_t hidden_dim   = 96;  // embedding dim of the first stage

    std::vector<int32_t> depths    = {2, 2, 6, 2};   // layers per stage
    std::vector<int32_t> num_heads = {3, 6, 12, 24}; // attention heads per stage

    float mlp_ratio      = 4.0f;  // FFN expansion ratio
    float drop_path_rate = 0.1f;  // stochastic-depth rate
    float norm_eps       = 1e-5f; // layer-norm epsilon
    bool  use_checkpoint = false; // gradient-checkpointing flag

    // Computed values

    // Number of stages, derived from depths. The explicit cast avoids the
    // implicit size_t -> int32_t narrowing conversion the original had.
    int32_t num_stages() const { return (int32_t) depths.size(); }

    // Total number of patches after patch embedding (patches per side, squared).
    int32_t num_patches() const { return (image_size / patch_size) * (image_size / patch_size); }
};
// Swin Transformer layer
// One Swin Transformer layer (window attention + FFN).
// All tensors are non-owning views into the model's ggml context; they are
// null-initialized here for consistency with swin_stage / swin_ctx, so a
// default-constructed layer never carries indeterminate pointers (the
// original left every member uninitialized).
struct swin_layer {
    // Window attention
    struct ggml_tensor * ln1_w  = nullptr; // norm1 weight
    struct ggml_tensor * ln1_b  = nullptr; // norm1 bias
    struct ggml_tensor * qkv_w  = nullptr; // fused q/k/v projection weight
    struct ggml_tensor * qkv_b  = nullptr; // fused q/k/v projection bias
    struct ggml_tensor * proj_w = nullptr; // attention output projection weight
    struct ggml_tensor * proj_b = nullptr; // attention output projection bias
    struct ggml_tensor * relative_position_bias_table = nullptr;

    // FFN
    struct ggml_tensor * ln2_w = nullptr; // norm2 weight
    struct ggml_tensor * ln2_b = nullptr; // norm2 bias
    struct ggml_tensor * fc1_w = nullptr; // mlp.fc1 weight
    struct ggml_tensor * fc1_b = nullptr; // mlp.fc1 bias
    struct ggml_tensor * fc2_w = nullptr; // mlp.fc2 weight
    struct ggml_tensor * fc2_b = nullptr; // mlp.fc2 bias
};
// Swin Transformer stage
// One stage of the Swin Transformer: a stack of layers, optionally followed
// by a patch-merging (downsample) step.
struct swin_stage {
    std::vector<swin_layer> layers; // transformer layers belonging to this stage

    // Patch-merging (downsample) weights; left null when the stage has no
    // downsample step.
    struct ggml_tensor * downsample_norm_w   {nullptr};
    struct ggml_tensor * downsample_norm_b   {nullptr};
    struct ggml_tensor * downsample_reduction{nullptr};
};
// Swin Transformer vision model
// Swin Transformer vision model: patch embedding, a sequence of stages, and a
// final output norm. Tensor pointers are null-initialized for consistency
// with swin_stage / swin_ctx (the original left them uninitialized, so a
// default-constructed model carried indeterminate pointers).
struct swin_vision_model {
    swin_hparams hparams;

    // Patch embedding
    struct ggml_tensor * patch_embed  = nullptr;
    struct ggml_tensor * patch_norm_w = nullptr;
    struct ggml_tensor * patch_norm_b = nullptr;
    struct ggml_tensor * pos_embed    = nullptr;

    // Stages
    std::vector<swin_stage> stages;

    // Output norm
    struct ggml_tensor * output_norm_w = nullptr;
    struct ggml_tensor * output_norm_b = nullptr;
};
// Main Swin context
// Main Swin context: model weights plus the ggml backend/compute state needed
// to run inference. Created by swin_model_load(), destroyed by swin_free().
struct swin_ctx {
// presumably set by swin_model_load() when the corresponding weights are
// found — confirm against the implementation
bool has_vision_encoder = false;
bool has_projector = false;
swin_vision_model vision_model;
// Backend and compute
struct ggml_backend * backend = nullptr;
// buffer holding the model weights
ggml_backend_buffer_t params_buffer = nullptr;
// ggml context owning the weight tensor metadata
struct ggml_context * ctx = nullptr;
// scratch memory for building compute-graph metadata
std::vector<uint8_t> buf_compute_meta;
// GGML compute resources
struct ggml_backend_buffer * compute_buffer = nullptr;
struct ggml_context * ctx_compute = nullptr;
// NOTE(review): ggml does not declare a type named `ggml_alloc` — upstream
// graph allocators are `ggml_allocr` (legacy) / `ggml_gallocr` (current).
// Confirm this matches the type used in the .cpp before relying on it.
struct ggml_alloc * compute_alloc = nullptr;
};
// Public API functions

// Forward declarations for the image types used by the API below. The
// original header referenced these without any visible declaration (and
// without `struct` elaboration), which fails to compile unless another
// header happens to provide them; their definitions live elsewhere.
struct swin_image_u8;
struct swin_image_f32;
struct swin_image_batch;
struct swin_image_f32_batch;

// Load a Swin model from a GGUF file. The returned context is owned by the
// caller and must be released with swin_free().
struct swin_ctx * swin_model_load(const std::string & fname, int verbosity = 1);

// Release all resources owned by the context.
void swin_free(struct swin_ctx * ctx);

// Build Swin Transformer graph for inference.
// load_image_size == {0, 0} presumably means "use hparams.image_size" —
// confirm against the implementation.
struct ggml_cgraph * swin_build_graph(
        struct swin_ctx * ctx,
        const swin_image_batch * imgs,
        std::pair<int, int> load_image_size = {0, 0},
        bool is_inf = false);

// Encode an image batch; embeddings are written into `vec`, a caller-allocated
// buffer (required size is not visible in this header — see the .cpp).
// Returns false on failure.
bool swin_image_batch_encode(
        struct swin_ctx * ctx,
        int n_threads,
        const swin_image_batch * imgs,
        float * vec);

// Utility functions
int swin_patch_size(const struct swin_ctx * ctx);
bool swin_image_preprocess(struct swin_ctx * ctx, const swin_image_u8 * img, swin_image_f32 * res);
bool swin_image_batch_preprocess(struct swin_ctx * ctx, int n_threads, const swin_image_batch * imgs, swin_image_f32_batch * res_batch);

// Window operations for Swin Transformer
// Partition a feature map into non-overlapping window_size x window_size tiles.
struct ggml_tensor * swin_window_partition(struct ggml_context * ctx, struct ggml_tensor * x, int window_size);
// Inverse of swin_window_partition for an H x W feature map.
struct ggml_tensor * swin_window_reverse(struct ggml_context * ctx, struct ggml_tensor * windows, int window_size, int H, int W);
// Attention masks for shifted-window attention (two variants; their exact
// difference is defined in the implementation).
struct ggml_tensor * swin_create_window_mask(struct ggml_context * ctx, int window_size, int shift_size, int H, int W);
struct ggml_tensor * swin_compute_mask(struct ggml_context * ctx, int window_size, int shift_size, int H, int W);