Removed all ggml_cont calls wrapping ggml_reshape_4d
This commit is contained in:
parent
a6b2c450c8
commit
6216273ede
|
|
@ -540,9 +540,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunkin
|
|||
// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
|
||||
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
|
||||
|
||||
ggml_tensor * k_i = ggml_cont(ctx0, ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB));
|
||||
ggml_tensor * k_j = ggml_cont(ctx0, ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB));
|
||||
ggml_tensor * q_i = ggml_cont(ctx0, ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB));
|
||||
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
|
||||
ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
|
||||
ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
|
||||
|
||||
ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
|
||||
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
|
||||
|
|
|
|||
Loading…
Reference in New Issue