diff --git a/ge/host_kernels/dynamic_stitch_kernel.cc b/ge/host_kernels/dynamic_stitch_kernel.cc index 3037934e..52f6cdcf 100644 --- a/ge/host_kernels/dynamic_stitch_kernel.cc +++ b/ge/host_kernels/dynamic_stitch_kernel.cc @@ -111,8 +111,9 @@ void DynamicStitchKernel::ComputeMergedShape(const vector &inp int32_t merged_first_dim = 0; int64_t indices_shape_size = 0; for (int i = 0; i < n_; i++) { - indices_shape_size = input[i]->GetTensorDesc().GetShape().GetShapeSize(); - indices_shape_size = indices_shape_size == 0 ? 1 : indices_shape_size; + // shape is [] means scalar + indices_shape_size = + input[i]->GetTensorDesc().GetShape().GetDims().empty() ? 1 : input[i]->GetTensorDesc().GetShape().GetShapeSize(); const int32_t *input_indices = reinterpret_cast(input[i]->GetData().data()); for (int64_t j = 0; j < indices_shape_size; j++) { merged_first_dim = std::max(merged_first_dim, input_indices[j]);