GitOrigin-RevId: f97b3005fd
release-1.5
@@ -776,6 +776,10 @@ void TypeCvt::perform(DeviceTensorND &dest,
         intl::UniqPtrWithCN<megdnn::TypeCvt> &opr) {
     mgb_assert(src.comp_node() == opr.comp_node());
     mgb_assert(dest_type.valid());
+    if (src.empty()) {
+        mgb_assert(dest.empty());
+        return;
+    }
     if (src.dtype() == dest_type) {
         dest.copy_from(src);
         return;
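
Note on the TypeCvt hunk above: DeviceTensorND::empty() is true when the shape contains a zero-length axis, so the new guard simply skips the megdnn kernel when there is nothing to convert. Below is a minimal sketch (not part of the patch) of the caller-side contract the guard relies on, reusing only APIs that already appear elsewhere in this patch; the comp node and shape values are illustrative.

// Sketch only: the caller sizes `dest` from `src`, so an empty source
// implies an empty destination and the conversion can be skipped
// before the megdnn kernel is reached.
auto cn = CompNode::load("xpu0");
HostTensorGenerator<> gen;
DeviceTensorND src{cn};
src.copy_from(*gen({2, 0, 3, 4}));          // zero-length axis -> empty tensor
DeviceTensorND dest{cn, dtype::Int32{}};
dest.resize(src.shape());                   // propagate the empty shape
mgb_assert(src.empty() && dest.empty());    // what the early return asserts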
@@ -1739,7 +1743,13 @@ void Reduce::record_execute_deps(ExecDependencyArray& deps) {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(PowC);
 
-MEGDNN_OPR_CTOR_INIT1(PowC, ssprintf("powc_%g", param.exp))
+PowC::PowC(VarNode *i0, const Param &param, const OperatorNodeConfig &config)
+        : Super(OperatorNodeBaseCtorParam{
+                  i0->owner_graph(), config, ssprintf("powc_%g", param.exp), {i0}}) {
+    init_megdnn_opr(*this, param);
+    add_input({i0});
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    intl::MegDNNOprInitPostCtor<PowC>::apply(*this);
+}
 
 SymbolVar PowC::make(SymbolVar x, const Param& param,
                      const OperatorNodeConfig& config) {
@@ -1778,6 +1788,22 @@ void PowC::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }
 
+void PowC::scn_do_execute() {
+    if (input(0)->dev_tensor().empty()) {
+        mgb_assert(output(0)->dev_tensor().empty());
+        return;
+    }
+    mgb_assert(!output(0)->dev_tensor().empty());
+    Super::scn_do_execute();
+}
+
+PowC::NodeProp* PowC::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PowC) {
     auto exp = opr.param().exp;
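
Taken together, the PowC hunks above follow a common MegBrain recipe for letting a megdnn-wrapped operator accept zero-element inputs. A condensed sketch of that recipe, lifted from the lines already in this patch (shown for orientation, not as additional code to apply):

// 1. Constructor: permit the statically inferred output shape to be empty.
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);

// 2. do_make_node_prop(): declare that the input's value may be an empty
//    tensor, so graph execution does not reject it up front.
ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);

// 3. scn_do_execute(): short-circuit before invoking the megdnn kernel.
if (input(0)->dev_tensor().empty()) {
    mgb_assert(output(0)->dev_tensor().empty());
    return;
}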
@@ -352,6 +352,8 @@ MGB_DEFINE_OPR_CLASS(PowC, intl::MegDNNOprWrapperFwd<megdnn::PowC>) // {
     void add_input_layout_constraint() override;
     void init_output_static_infer_desc() override;
     void mem_plan_fwd_in2out_writable() override;
+    NodeProp* do_make_node_prop() const override;
+    void scn_do_execute() override;
 
 public:
     PowC(VarNode* inp, const Param& param, const OperatorNodeConfig& config);
@@ -589,6 +589,23 @@ TEST(TestOprBasicArith, TypeCvtFromBool) {
     ASSERT_EQ(TensorShape({2}), host_y.shape());
 }
 
+TEST(TestOprBasicArith, TypeCvtPerformEmptyIO) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("xpu0");
+    auto host_x = gen({2, 0, 3, 4});
+    auto dev_x = std::make_shared<DeviceTensorND>(cn);
+    dev_x->copy_from(*host_x);
+    auto dev_y = std::make_shared<DeviceTensorND>(cn, dtype::Int32{});
+    dev_y->resize(dev_x->shape());
+    auto dnn_opr = opr::intl::create_megdnn_opr<megdnn::TypeCvt>(cn);
+    ASSERT_NO_THROW(
+            opr::TypeCvt::perform(*dev_y, dtype::Int32{}, *dev_x, dnn_opr));
+    ASSERT_TRUE(dev_y->empty());
+    ASSERT_TRUE(dev_y->shape().is_empty());
+    MGB_ASSERT_SHAPE_EQ(dev_x->shape(), dev_y->shape());
+}
+
 TEST(TestOprBasicArith, ElemwiseMemFwd) {
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
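
The new tests in this patch rely on the convention that a shape counts as empty as soon as any axis has length zero. A tiny sketch of that check: is_empty() is taken from the test above, while total_nr_elems() is assumed here only for illustration.

// Sketch only: what "empty" means for the shapes used in these tests.
TensorShape s{2, 0, 3, 4};
mgb_assert(s.is_empty());               // any zero-length axis makes the shape empty
mgb_assert(s.total_nr_elems() == 0);    // so the tensor holds no elements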
@@ -756,4 +773,19 @@ TEST(TestOprBasicArith, PowCInfer) {
     run(true);
 }
 
+TEST(TestOprBasicArith, PowCEmptyIO) {
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    // empty input
+    auto host_x = gen({4, 0, 2, 3});
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+         y = opr::PowC::make(x, 3.f);
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+    ASSERT_NO_THROW(func->execute().wait());
+    ASSERT_TRUE(host_y.empty());
+    ASSERT_TRUE(host_y.shape().is_empty());
+    MGB_ASSERT_SHAPE_EQ(host_x->shape(), host_y.shape());
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}