Browse Source

feat(mgb/opr): let PowC & TypeCvt support empty IO

GitOrigin-RevId: f97b3005fd
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
dea5278172
3 changed files with 61 additions and 1 deletions
  1. +27
    -1
      src/opr/impl/basic_arith.cpp
  2. +2
    -0
      src/opr/include/megbrain/opr/basic_arith.h
  3. +32
    -0
      src/opr/test/basic_arith/others.cpp

+ 27
- 1
src/opr/impl/basic_arith.cpp View File

@@ -776,6 +776,10 @@ void TypeCvt::perform(DeviceTensorND &dest,
intl::UniqPtrWithCN<megdnn::TypeCvt> &opr) {
mgb_assert(src.comp_node() == opr.comp_node());
mgb_assert(dest_type.valid());
if (src.empty()) {
mgb_assert(dest.empty());
return;
}
if (src.dtype() == dest_type) {
dest.copy_from(src);
return;
@@ -1739,7 +1743,13 @@ void Reduce::record_execute_deps(ExecDependencyArray& deps) {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(PowC);

MEGDNN_OPR_CTOR_INIT1(PowC, ssprintf("powc_%g", param.exp))
PowC::PowC(VarNode *i0, const Param &param, const OperatorNodeConfig &config)
: Super(OperatorNodeBaseCtorParam{ i0->owner_graph(), config, ssprintf("powc_%g", param.exp), {i0}} ) {
init_megdnn_opr(*this, param);
add_input({i0});
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
intl::MegDNNOprInitPostCtor<PowC>::apply(*this);
}

SymbolVar PowC::make(SymbolVar x, const Param& param,
const OperatorNodeConfig& config) {
@@ -1778,6 +1788,22 @@ void PowC::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
}

void PowC::scn_do_execute() {
if (input(0)->dev_tensor().empty()) {
mgb_assert(output(0)->dev_tensor().empty());
return;
}
mgb_assert(!output(0)->dev_tensor().empty());
Super::scn_do_execute();
}

PowC::NodeProp* PowC::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PowC) {
auto exp = opr.param().exp;


+ 2
- 0
src/opr/include/megbrain/opr/basic_arith.h View File

@@ -352,6 +352,8 @@ MGB_DEFINE_OPR_CLASS(PowC, intl::MegDNNOprWrapperFwd<megdnn::PowC>) // {
void add_input_layout_constraint() override;
void init_output_static_infer_desc() override;
void mem_plan_fwd_in2out_writable() override;
NodeProp* do_make_node_prop() const override;
void scn_do_execute() override;

public:
PowC(VarNode* inp, const Param& param, const OperatorNodeConfig& config);


+ 32
- 0
src/opr/test/basic_arith/others.cpp View File

@@ -589,6 +589,23 @@ TEST(TestOprBasicArith, TypeCvtFromBool) {
ASSERT_EQ(TensorShape({2}), host_y.shape());
}

TEST(TestOprBasicArith, TypeCvtPerformEmptyIO) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("xpu0");
auto host_x = gen({2, 0, 3, 4});
auto dev_x = std::make_shared<DeviceTensorND>(cn);
dev_x->copy_from(*host_x);

auto dev_y = std::make_shared<DeviceTensorND>(cn, dtype::Int32{});
dev_y->resize(dev_x->shape());

auto dnn_opr = opr::intl::create_megdnn_opr<megdnn::TypeCvt>(cn);
ASSERT_NO_THROW(opr::TypeCvt::perform(*dev_y, dtype::Int32{}, *dev_x, dnn_opr));
ASSERT_TRUE(dev_y->empty());
ASSERT_TRUE(dev_y->shape().is_empty());
MGB_ASSERT_SHAPE_EQ(dev_x->shape(), dev_y->shape());
}

TEST(TestOprBasicArith, ElemwiseMemFwd) {
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
@@ -756,4 +773,19 @@ TEST(TestOprBasicArith, PowCInfer) {
run(true);
}

TEST(TestOprBasicArith, PowCEmptyIO) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
// empty input
auto host_x = gen({4, 0, 2, 3});
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::PowC::make(x, 3.f);
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_y.empty());
ASSERT_TRUE(host_y.shape().is_empty());
MGB_ASSERT_SHAPE_EQ(host_x->shape(), host_y.shape());
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save