Merge 0fcc4cc8ba into 58c81f7e81
This commit is contained in:
commit
36e8a4b7f7
|
|
@ -3011,6 +3011,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
int sections[4];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_imrope || mrope_used) {
|
||||
is_neox = true;
|
||||
}
|
||||
|
||||
int64_t rope_dims = n_dims;
|
||||
if (is_vision) {
|
||||
rope_dims = src0->ne[0];
|
||||
}
|
||||
|
||||
// Run the full cache init on the non-captured stream. This performs all
|
||||
// host-to-device memcpy, aclrtMalloc/Free, and on-device computations
|
||||
// so that the memory pool is warmed up and cache metadata is populated.
|
||||
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
|
||||
mrope_used, is_imrope, is_vision, rope_dims);
|
||||
|
||||
// Reset `cached` so that during graph capture the on-device computations
|
||||
// (sin/cos, position multiply, repeat, etc.) still execute and get recorded
|
||||
// into the captured graph. The cache metadata (theta_scale_length,
|
||||
// theta_scale, sections, position_length, etc.) remains set, which causes
|
||||
// all host-to-device copy and malloc/free branches to be skipped.
|
||||
ctx.rope_cache.cached = false;
|
||||
}
|
||||
|
||||
void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
|
|
|
|||
|
|
@ -543,6 +543,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
|||
*/
|
||||
void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Pre-load the RoPE cache before ACL graph capture.
|
||||
*
|
||||
* This function must be called outside of graph capture to perform
|
||||
* host-to-device memory copies and device memory allocations that are
|
||||
* not allowed on a captured stream. After pre-loading, the rope cache
|
||||
* metadata is updated so that the subsequent call to
|
||||
* aclnn_rope_cache_init (inside graph capture) skips these operations
|
||||
* and only records the on-device computations into the captured graph.
|
||||
*
|
||||
* @param ctx CANN backend context.
|
||||
* @param dst A ROPE destination tensor from the computation graph.
|
||||
*/
|
||||
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the index of the maximum value along the specified dimension
|
||||
* of a ggml tensor using the CANN backend.
|
||||
|
|
|
|||
|
|
@ -277,7 +277,7 @@ struct ggml_graph_node_properties {
|
|||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
|
||||
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){
|
||||
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
|
||||
}
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -2225,6 +2225,19 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
|||
// If no matching graph is found, add a new ACL graph.
|
||||
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
|
||||
cann_ctx->graph_lru_cache.push(new_graph);
|
||||
|
||||
// Pre-load rope cache before graph capture. During capture the
|
||||
// stream cannot perform host-to-device memcpy or device memory
|
||||
// malloc/free. Running the full cache init now populates the
|
||||
// cache metadata so these branches are skipped during capture,
|
||||
// while also warming up the memory pool.
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
if (node->op == GGML_OP_ROPE) {
|
||||
ggml_cann_rope_cache_preload(*cann_ctx, node);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
|
|
|||
Loading…
Reference in New Issue