diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp index 7853545a..46d3e437 100644 --- a/src/opr/impl/atlas_runtime_op.cpp +++ b/src/opr/impl/atlas_runtime_op.cpp @@ -278,6 +278,9 @@ void AtlasRuntimeOpr::scn_do_execute() { for (size_t i = 0; i < output().size(); i++) { auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i); auto ovar = output(i); + output_size = std::max( + output_size, + ovar->dtype().size(ovar->shape().total_nr_elems())); ovar->shape_alloc(ovar->shape(), output_size); } } diff --git a/src/opr/test/atlas_runtime_op.cpp b/src/opr/test/atlas_runtime_op.cpp index 19ab0f24..b837da44 100644 --- a/src/opr/test/atlas_runtime_op.cpp +++ b/src/opr/test/atlas_runtime_op.cpp @@ -65,7 +65,7 @@ TEST(TestOprAtlas, Basic) { } TEST(TestOprAtlas, DynamicBatch) { - for (size_t batch : {1, 6}) { + for (size_t batch : {1, 6, 20}) { HostTensorGenerator<> gen; const auto& graph = ComputingGraph::make(); const auto& host_x = gen({batch, 3, 16, 16});