Browse Source

fix(mgb/atlas): when batchsize more than model max batchsize

GitOrigin-RevId: 63fe79eaa9
release-1.2
Megvii Engine Team 4 years ago
parent
commit
e92670e820
2 changed files with 4 additions and 1 deletions
  1. +3
    -0
      src/opr/impl/atlas_runtime_op.cpp
  2. +1
    -1
      src/opr/test/atlas_runtime_op.cpp

+ 3
- 0
src/opr/impl/atlas_runtime_op.cpp View File

@@ -278,6 +278,9 @@ void AtlasRuntimeOpr::scn_do_execute() {
for (size_t i = 0; i < output().size(); i++) {
auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
auto ovar = output(i);
output_size = std::max<size_t>(
output_size,
ovar->dtype().size(ovar->shape().total_nr_elems()));
ovar->shape_alloc(ovar->shape(), output_size);
}
}


+ 1
- 1
src/opr/test/atlas_runtime_op.cpp View File

@@ -65,7 +65,7 @@ TEST(TestOprAtlas, Basic) {
}

TEST(TestOprAtlas, DynamicBatch) {
for (size_t batch : {1, 6}) {
for (size_t batch : {1, 6, 20}) {
HostTensorGenerator<> gen;
const auto& graph = ComputingGraph::make();
const auto& host_x = gen({batch, 3, 16, 16});


Loading…
Cancel
Save