This commit is contained in:
Chenguang Li 2026-03-20 23:37:39 +09:00 committed by GitHub
commit 36e8a4b7f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 81 additions and 1 deletions

View File

@ -3011,6 +3011,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
}
}
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
GGML_TENSOR_UNARY_OP_LOCALS
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_imrope || mrope_used) {
is_neox = true;
}
int64_t rope_dims = n_dims;
if (is_vision) {
rope_dims = src0->ne[0];
}
// Run the full cache init on the non-captured stream. This performs all
// host-to-device memcpy, aclrtMalloc/Free, and on-device computations
// so that the memory pool is warmed up and cache metadata is populated.
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
mrope_used, is_imrope, is_vision, rope_dims);
// Reset `cached` so that during graph capture the on-device computations
// (sin/cos, position multiply, repeat, etc.) still execute and get recorded
// into the captured graph. The cache metadata (theta_scale_length,
// theta_scale, sections, position_length, etc.) remains set, which causes
// all host-to-device copy and malloc/free branches to be skipped.
ctx.rope_cache.cached = false;
}
void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];

View File

@ -543,6 +543,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/
void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Pre-load the RoPE cache before ACL graph capture.
*
* This function must be called outside of graph capture to perform
* host-to-device memory copies and device memory allocations that are
* not allowed on a captured stream. After pre-loading, the rope cache
* metadata is updated so that the subsequent call to
* aclnn_rope_cache_init (inside graph capture) skips these operations
* and only records the on-device computations into the captured graph.
*
* @param ctx CANN backend context.
* @param dst A ROPE destination tensor from the computation graph.
*/
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the index of the maximum value along the specified dimension
* of a ggml tensor using the CANN backend.

View File

@ -277,7 +277,7 @@ struct ggml_graph_node_properties {
}
}
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;

View File

@ -2225,6 +2225,19 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// If no matching graph is found, add a new ACL graph.
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
cann_ctx->graph_lru_cache.push(new_graph);
// Pre-load rope cache before graph capture. During capture the
// stream cannot perform host-to-device memcpy or device memory
// malloc/free. Running the full cache init now populates the
// cache metadata so these branches are skipped during capture,
// while also warming up the memory pool.
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (node->op == GGML_OP_ROPE) {
ggml_cann_rope_cache_preload(*cann_ctx, node);
break;
}
}
}
}
#else