add tensor type checking as part of cuda graph properties (#19186)
This commit is contained in:
parent
1025fd2c09
commit
ecbf01d441
|
|
@ -1124,6 +1124,7 @@ struct ggml_tensor_extra_gpu {
|
||||||
struct ggml_cuda_graph_node_properties {
|
struct ggml_cuda_graph_node_properties {
|
||||||
void * node_data;
|
void * node_data;
|
||||||
ggml_op node_op;
|
ggml_op node_op;
|
||||||
|
enum ggml_type node_type;
|
||||||
int32_t flags;
|
int32_t flags;
|
||||||
int64_t ne[GGML_MAX_DIMS];
|
int64_t ne[GGML_MAX_DIMS];
|
||||||
size_t nb[GGML_MAX_DIMS];
|
size_t nb[GGML_MAX_DIMS];
|
||||||
|
|
|
||||||
|
|
@ -2920,6 +2920,7 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties
|
||||||
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
|
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
|
||||||
props->node_data = node->data;
|
props->node_data = node->data;
|
||||||
props->node_op = node->op;
|
props->node_op = node->op;
|
||||||
|
props->node_type = node->type;
|
||||||
props->flags = node->flags;
|
props->flags = node->flags;
|
||||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||||
props->ne[i] = node->ne[i];
|
props->ne[i] = node->ne[i];
|
||||||
|
|
@ -2944,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node->type != props->node_type) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||||
if (node->ne[i] != props->ne[i]) {
|
if (node->ne[i] != props->ne[i]) {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue