Browse Source

fix(mgb/opr): fix Reduce static value inference

GitOrigin-RevId: 5e5c56064c
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
99b176236b
3 changed files with 42 additions and 5 deletions
  1. +5
    -3
      src/opr/impl/basic_arith.cpp
  2. +1
    -0
      src/opr/include/megbrain/opr/basic_arith.h
  3. +36
    -2
      src/opr/test/basic_arith/reduction.cpp

+ 5
- 3
src/opr/impl/basic_arith.cpp View File

@@ -1501,8 +1501,9 @@ void Reduce::init_output_static_infer_desc() {
auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
DeviceTensorND workspace;
auto sopr = static_infer_opr.lock();
perform(m_param.mode, dest, workspace,
inp.val[0].value(), inp.val.at(1).shape(), sopr(), m_param.data_type);
perform(m_param.mode, dest, workspace, inp.val[0].value(),
output(0)->dtype(), inp.val.at(1).shape(), sopr(),
m_param.data_type);
return true;
};

@@ -1632,6 +1633,7 @@ void Reduce::perform(
Mode mode,
DeviceTensorND &dest, DeviceTensorND &workspace,
const DeviceTensorND &input,
const DType &target_dtype,
const TensorShape &target_shape,
intl::UniqPtrWithCN<megdnn::Reduce> &opr, const Param::DataType data_type) {

@@ -1674,7 +1676,7 @@ void Reduce::perform(
}

opr.comp_node().activate();
dest.comp_node(opr.comp_node()).dtype(input.dtype()).resize(target_shape);
dest.comp_node(opr.comp_node()).dtype(target_dtype).resize(target_shape);
ksched.update_ptr(*input_contig, dest, workspace);
ksched.execute(opr.get(), *input_contig, dest);
}


+ 1
- 0
src/opr/include/megbrain/opr/basic_arith.h View File

@@ -304,6 +304,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
static void perform(Mode mode, DeviceTensorND& dest,
DeviceTensorND& workspace,
const DeviceTensorND& input,
const DType& target_dtype,
const TensorShape& target_shape,
intl::UniqPtrWithCN<megdnn::Reduce>& opr,
const Param::DataType data_type=Param::DataType::DEFAULT);


+ 36
- 2
src/opr/test/basic_arith/reduction.cpp View File

@@ -298,7 +298,8 @@ namespace {
static_calc_x.copy_from(*host_x);
opr::Reduce::perform(
Mode::SUM, static_calc_y, static_calc_workspace,
static_calc_x, oshp, static_calc_opr);
static_calc_x, dtype::Float32(), oshp,
static_calc_opr);
host_y.ptr<float>()[0] ++;
host_y.copy_from(static_calc_y);
MGB_ASSERT_TENSOR_NEAR(expected, host_y, 1e-5);
@@ -468,7 +469,8 @@ TEST(TestBasicArithReduction, NonContPerform) {
for (auto &&tshp:
TensorShapeArray{{5, 1}, {1, 5}, {1, 1}, {1}, {5, 5}}) {

opr::Reduce::perform(mode, y, workspace, x, tshp, opr);
opr::Reduce::perform(mode, y, workspace, x, dtype::Float32(), tshp,
opr);
ASSERT_TRUE(y.layout().is_contiguous());
ASSERT_EQ(tshp, y.shape());
size_t nr = tshp.total_nr_elems();
@@ -866,4 +868,36 @@ TEST(TestBasicArithReduction, StaticInferValue) {
MGB_ASSERT_TENSOR_EQ(inferred, expected);
}

//! Verify that static value inference of Reduce produces the output dtype
//! implied by Param::DataType: DEFAULT keeps the input dtype, while
//! FLOAT_O32xC32 / FLOAT_O16xC32 force a float32 / float16 output
//! regardless of the input dtype.
TEST(TestBasicArithReduction, StaticInferValueDType) {
    using ParamType = opr::Reduce::Param::DataType;
    DType F32 = dtype::Float32(), F16 = dtype::Float16();

    auto check = [](const DType& in_dtype, const DType& expected_out_dtype,
                    ParamType data_type) {
        HostTensorGenerator<> gen;
        auto host_inp = gen({2, 3, 4, 5});
        // target-shape tensor: collapse everything into a single element
        auto host_shape = std::make_shared<HostTensorND>(host_inp->comp_node(),
                                                         dtype::Int32());
        host_shape->resize({1});
        host_shape->ptr<int>()[0] = 1;

        auto graph = ComputingGraph::make();
        auto inp_f32 = opr::Host2DeviceCopy::make(*graph, host_inp);
        auto inp = opr::TypeCvt::make(inp_f32, in_dtype);
        auto shape = opr::Host2DeviceCopy::make(*graph, host_shape);
        auto out = opr::Reduce::make(
                inp, {opr::Reduce::Mode::SUM, MEGDNN_MAX_NDIM, data_type},
                shape);
        auto inferred = graph->static_infer_manager().infer_value(out.node());
        ASSERT_EQ(inferred.layout().dtype, expected_out_dtype);
    };

    check(F32, F32, ParamType::DEFAULT);
    check(F16, F16, ParamType::DEFAULT);
    check(F32, F32, ParamType::FLOAT_O32xC32);
    check(F16, F32, ParamType::FLOAT_O32xC32);
    check(F32, F16, ParamType::FLOAT_O16xC32);
    check(F16, F16, ParamType::FLOAT_O16xC32);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save