From 38b492727e35668a7688761978f3d2a80087f664 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 16 May 2022 13:36:45 +0800 Subject: [PATCH] fix(opr): fix no update ptr in reduce operator when input change GitOrigin-RevId: a443a79ac0a73cd984c163c56771dfc49f9cdebd --- src/opr/impl/basic_arith.cpp | 3 +++ src/opr/test/basic_arith/reduction.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index fc6c02b0..f7d5479d 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -1648,6 +1648,9 @@ void Reduce::scn_do_execute() { m_kern_scheduler->check_shapes(inp.shape(), out_ptr->shape()); if (m_kern_scheduler->has_actual_computing()) { + m_kern_scheduler->update_ptr( + inp, *out_ptr, + output(1)->shape()[0] ? output(1)->dev_tensor() : DeviceTensorND{}); m_kern_scheduler->execute( static_cast(megdnn_opr()), inp, *out_ptr); } else { diff --git a/src/opr/test/basic_arith/reduction.cpp b/src/opr/test/basic_arith/reduction.cpp index c1bc6840..15baae25 100644 --- a/src/opr/test/basic_arith/reduction.cpp +++ b/src/opr/test/basic_arith/reduction.cpp @@ -416,6 +416,31 @@ TEST(TestBasicArithReduction, NonContFwd) { } } +TEST(TestBasicArithReduction, ResetMemory) { + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + auto host_x = gen({3, 2}); + auto host_tshp = + std::make_shared(host_x->comp_node(), dtype::Int32()); + host_tshp->resize({1}); + host_tshp->ptr()[0] = 1; + + auto tshp = opr::Host2DeviceCopy::make(*graph, host_tshp, {"tshp"}); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + auto y = opr::reduce_max(x, tshp); + + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute(); + func->wait(); + + //! only reset the host x memory, make sure the case can run normal + auto host_x_tmp = gen({3, 2}); + host_x->reset(host_x_tmp->storage(), host_x_tmp->layout()); + func->execute(); + func->wait(); +} + TEST(TestBasicArithReduction, NonContPerform) { DeviceTensorND x{CompNode::default_cpu(), dtype::Float32()}, y{x.comp_node(), x.dtype()}, workspace;