CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (#16577)
This commit is contained in:
parent
4258e0cfe7
commit
120bf7046d
|
|
@ -2876,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||||
}
|
}
|
||||||
|
|
||||||
//if rms norm is the B operand, then we don't handle broadcast
|
//if rms norm is the B operand, then we don't handle broadcast
|
||||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2686,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
||||||
|
|
||||||
// if rms_norm is the B operand, then we don't handle broadcast
|
// if rms_norm is the B operand, then we don't handle broadcast
|
||||||
if (rms_norm == mul->src[1] &&
|
if (rms_norm == mul->src[1] &&
|
||||||
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
!ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue