diff --git a/dnn/src/common/reduce.cpp b/dnn/src/common/reduce.cpp
index 99ec0bab..9b5e6d5b 100644
--- a/dnn/src/common/reduce.cpp
+++ b/dnn/src/common/reduce.cpp
@@ -78,8 +78,10 @@ void ReduceForward::check_exec(const TensorLayout& src, const TensorLayout& dst,
             megdnn_assert(dst.shape[i] == 1_z, "%s", errmsg().c_str());
         }
     }
-    megdnn_assert(src.dtype.category() == dst.dtype.category(),
-                  "the category of reduce output and input must be the same");
+    megdnn_assert(src.dtype.category() == dst.dtype.category() ||
+                          param().data_type == Reduce::DataType::FLOAT_O32xC32,
+                  "the category of reduce output and input must be the same,"
+                  " or the data_type is FLOAT_O32xC32");
     if (param().data_type == DataType::DEFAULT) {
         megdnn_assert(src.dtype == dst.dtype &&
                       (src.dtype.category() == DTypeCategory::FLOAT ||
@@ -89,8 +91,11 @@ void ReduceForward::check_exec(const TensorLayout& src, const TensorLayout& dst,
         megdnn_assert(src.dtype.enumv() == DTypeEnum::Quantized8Asymm);
     } else if (param().data_type == DataType::QINT_I8xO32) {
         megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
-    } else {
+    } else if (param().data_type == DataType::FLOAT_IO16xC32 ||
+               param().data_type == DataType::FLOAT_O16xC32) {
         megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
+    } else {
+        megdnn_assert(param().data_type == DataType::FLOAT_O32xC32);
     }
 
     auto expected = get_out_dtype(param().data_type, src.dtype);
diff --git a/dnn/src/cuda/reduce/opr_impl.cpp b/dnn/src/cuda/reduce/opr_impl.cpp
index b1ec8b54..ca677c7e 100644
--- a/dnn/src/cuda/reduce/opr_impl.cpp
+++ b/dnn/src/cuda/reduce/opr_impl.cpp
@@ -28,6 +28,7 @@ size_t dispatch_dtype_workspace(const TensorLayout& src, const TensorLayout&,
                                 Reduce::DataType data_type) {
     using f16 = DTypeTrait<dtype::Float16>::ctype;
     using f32 = DTypeTrait<dtype::Float32>::ctype;
+    using i32 = DTypeTrait<dtype::Int32>::ctype;
     if (data_type == Reduce::DataType::DEFAULT) {
 #define cb(_dt)                        \
     case DTypeTrait<_dt>::enumv: {     \
@@ -46,6 +47,8 @@ size_t dispatch_dtype_workspace(const TensorLayout& src, const TensorLayout&,
             return get_reduce_workspace_in_bytes<Op<f16, f32, f32>>(A, B, C);
         else if (src.dtype == dtype::Float32())
             return get_reduce_workspace_in_bytes<Op<f32, f32, f32>>(A, B, C);
+        else if (src.dtype == dtype::Int32())
+            return get_reduce_workspace_in_bytes<Op<i32, f32, f32>>(A, B, C);
     } else if (data_type == Reduce::DataType::FLOAT_O16xC32) {
         if (src.dtype == dtype::Float16())
             return get_reduce_workspace_in_bytes<Op<f16, f16, f32>>(A, B, C);
@@ -61,6 +64,7 @@ void dispatch_dtype(cudaStream_t stream, const TensorND& src,
                     size_t B, size_t C, Reduce::DataType data_type) {
     using f16 = DTypeTrait<dtype::Float16>::ctype;
     using f32 = DTypeTrait<dtype::Float32>::ctype;
+    using i32 = DTypeTrait<dtype::Int32>::ctype;
     if (data_type == Reduce::DataType::DEFAULT) {
         switch (src.layout.dtype.enumv()) {
 #define cb(_dt)                        \
@@ -80,10 +84,14 @@ void dispatch_dtype(cudaStream_t stream, const TensorND& src,
             return run_reduce<Op<f16, f32, f32>, false>(
                     workspace.ptr<f32>(), A, B, C, stream,
                     Op<f16, f32, f32>(src.ptr<f16>(), dst.ptr<f32>(), B));
-        } else {
+        } else if (src.layout.dtype == dtype::Float32()) {
             return run_reduce<Op<f32, f32, f32>, false>(
                     workspace.ptr<f32>(), A, B, C, stream,
                     Op<f32, f32, f32>(src.ptr<f32>(), dst.ptr<f32>(), B));
+        } else if (src.layout.dtype == dtype::Int32()) {
+            return run_reduce<Op<i32, f32, f32>, false>(
+                    workspace.ptr<f32>(), A, B, C, stream,
+                    Op<i32, f32, f32>(src.ptr<i32>(), dst.ptr<f32>(), B));
         }
     } else if (data_type == Reduce::DataType::FLOAT_O16xC32) {
         if (src.layout.dtype == dtype::Float16()) {
diff --git a/dnn/src/cuda/reduce/reduce.cu b/dnn/src/cuda/reduce/reduce.cu
index 02a48b8b..23dcc39c 100644
--- a/dnn/src/cuda/reduce/reduce.cu
+++ b/dnn/src/cuda/reduce/reduce.cu
@@ -36,6 +36,7 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
 INST(dt_float16, dt_float16, float)
 INST(dt_float16, float, float)
 INST(float, dt_float16, float)
+INST(int, float, float)
 
 #undef cb
 #undef INST
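
Note (illustrative aside, not part of the patch): the three arguments of each INST/Op instantiation are the source, destination, and computing (accumulator) ctypes, so INST(int, float, float) is the new Int32-input path of FLOAT_O32xC32: read int, accumulate and write float32. A minimal standalone sketch of that type split, assuming sum semantics:

// Editor's sketch only: the src/dst/computing split behind
// FLOAT_O32xC32 for an Int32 input, mirroring INST(int, float, float).
#include <cstdio>
#include <vector>

template <typename src_ctype, typename dst_ctype, typename wtype>
dst_ctype reduce_sum(const std::vector<src_ctype>& src) {
    wtype accum = 0;                       // accumulate in the computing type
    for (src_ctype v : src)
        accum += static_cast<wtype>(v);    // convert per element, not the sum
    return static_cast<dst_ctype>(accum);  // store in the output type
}

int main() {
    std::vector<int> x{1, 2, 3, 4};
    std::printf("%f\n", reduce_sum<int, float, float>(x));  // prints 10.000000
}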
diff --git a/dnn/test/cuda/reduce.cpp b/dnn/test/cuda/reduce.cpp
index 9e863f29..a3eee5dc 100644
--- a/dnn/test/cuda/reduce.cpp
+++ b/dnn/test/cuda/reduce.cpp
@@ -80,6 +80,8 @@ TEST_F(CUDA, REDUCE) {
     }
     check(mode, dtype::Float16(), dtype::Float32(),
           Reduce::DataType::FLOAT_O32xC32);
+    check(mode, dtype::Int32(), dtype::Float32(),
+          Reduce::DataType::FLOAT_O32xC32);
     check(mode, dtype::Float16(), dtype::Float16(),
           Reduce::DataType::FLOAT_O16xC32);
     check(mode, dtype::Float32(), dtype::Float16(),
diff --git a/dnn/test/fallback/reduce.cpp b/dnn/test/fallback/reduce.cpp
index 840d736b..72f075fa 100644
--- a/dnn/test/fallback/reduce.cpp
+++ b/dnn/test/fallback/reduce.cpp
@@ -50,6 +50,10 @@ TEST_F(FALLBACK, REDUCE) {
             param.data_type = DataType::FLOAT_O32xC32;
             config = Config(param, dtype, shape);
             configs.push_back(config);
+        } else if (dtype == dtype::Int32()) {
+            Param param(mode, axis, DataType::FLOAT_O32xC32);
+            Config config(param, dtype, shape);
+            configs.push_back(config);
         }
     }
     // large (ABC) -> (A1C) case
diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp
index 20cad2aa..489d9689 100644
--- a/src/opr/impl/basic_arith.cpp
+++ b/src/opr/impl/basic_arith.cpp
@@ -1680,7 +1680,7 @@ void Reduce::create_megdnn_opr() {
 MGB_IMPL_OPR_GRAD(Reduce) {
     for (size_t i = 1; i < opr.output().size(); ++ i)
         mgb_assert(!out_grad[i]);
-    if (wrt_idx)
+    if (wrt_idx || opr.input(0)->dtype().category() != DTypeCategory::FLOAT)
         return InvalidGrad::make(opr, wrt_idx);
     SymbolVar og{out_grad[0]}, iv{opr.input(0)}, ov{opr.output(0)};
     constexpr auto cmv = Elemwise::Mode::COND_LEQ_MOV;
@@ -1700,8 +1700,9 @@ MGB_IMPL_OPR_GRAD(Reduce) {
         case Mode::MEAN: {
             auto og_shape = opr::GetVarShape::make(og),
                  iv_shape = opr::GetVarShape::make(iv),
-                 scale = opr::reduce_prod(og_shape, og_shape.make_scalar(1)) /
-                         opr::reduce_prod(iv_shape, iv_shape.make_scalar(1));
+                 scale = div(
+                         opr::reduce_prod(og_shape, og_shape.make_scalar(1)),
+                         opr::reduce_prod(iv_shape, iv_shape.make_scalar(1)));
             return scale * Broadcast::make(og, GetVarShape::make(iv));
         }
         default:
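
Note (illustrative aside, not part of the patch): the MEAN gradient scale is prod(out_shape) / prod(in_shape), and GetVarShape yields integer vars, so plain `/` truncates; reducing (22, 21) to (22, 1) gives 22 / 462 = 0 and the gradient vanishes. The patch routes the quotient through the div(...) helper in that file's scope (name taken from the patch as-is). A toy reproduction of the truncation:

// Editor's sketch only: why the MEAN gradient scale must not use
// integer division; prod(out_shape) <= prod(in_shape) truncates to 0.
#include <cstdio>

int main() {
    int out_elems = 22;                                // (22, 21) -> (22, 1)
    int in_elems = 22 * 21;
    int truncated = out_elems / in_elems;              // 0: gradient vanishes
    float scale = float(out_elems) / float(in_elems);  // correct 1/21 factor
    std::printf("%d %f\n", truncated, scale);
}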
diff --git a/src/opr/test/basic_arith/reduction.cpp b/src/opr/test/basic_arith/reduction.cpp
index ca0286f3..6b6ea09e 100644
--- a/src/opr/test/basic_arith/reduction.cpp
+++ b/src/opr/test/basic_arith/reduction.cpp
@@ -27,6 +27,7 @@ using namespace mgb;
 namespace {
 
     using Mode = opr::Reduce::Mode;
+    using DataType = opr::Reduce::Param::DataType;
 
     template <Mode mode, class ctype>
     struct ImplTrait {
@@ -43,6 +44,10 @@ namespace {
         static ctype reduce(ctype accum, ctype v) {
            return accum + v;
         }
+
+        ctype finalize(ctype result) {
+            return result;
+        }
     };
 
     template <class ctype>
@@ -56,6 +61,10 @@ namespace {
         static ctype reduce(ctype accum, ctype v) {
            return accum + v * v;
         }
+
+        ctype finalize(ctype result) {
+            return result;
+        }
     };
 
     template <class ctype>
@@ -69,6 +78,10 @@ namespace {
         static ctype reduce(ctype accum, ctype v) {
            return accum * v;
         }
+
+        ctype finalize(ctype result) {
+            return result;
+        }
     };
 
     template <class ctype>
@@ -82,6 +95,10 @@ namespace {
         static ctype reduce(ctype accum, ctype v) {
            return std::max(accum, v);
         }
+
+        ctype finalize(ctype result) {
+            return result;
+        }
     };
 
     template <class ctype>
@@ -95,6 +112,30 @@ namespace {
         static ctype reduce(ctype accum, ctype v) {
            return std::min(accum, v);
         }
+
+        ctype finalize(ctype result) {
+            return result;
+        }
+    };
+
+    template <class ctype>
+    struct ImplTrait<Mode::MEAN, ctype> {
+        static constexpr float GRAD_MAXERR = 1e-4, GRAD_EPS = 1e-2;
+        size_t nr_elems;
+
+        ctype init() {
+            nr_elems = 0;
+            return 0;
+        }
+
+        ctype reduce(ctype accum, ctype v) {
+            nr_elems ++;
+            return accum + v;
+        }
+
+        ctype finalize(ctype result) {
+            return result / static_cast<ctype>(nr_elems);
+        }
     };
 
     template <Mode mode, class ctype>
@@ -108,10 +149,11 @@ namespace {
             return;
         }
 
-        ctype val = Impl::init();
+        Impl impl;
+        ctype val = impl.init();
         for (auto i: megdnn::tensor_iter_valonly<ctype>(src.as_megdnn()))
-            val = Impl::reduce(val, i);
-        dest.ptr<ctype>()[0] = val;
+            val = impl.reduce(val, i);
+        dest.ptr<ctype>()[0] = impl.finalize(val);
         return;
     }
@@ -143,15 +185,16 @@ namespace {
             for (size_t i = 0; i < tshp.ndim; i ++)
                 offset += iter.idx()[i] * src.layout().stride[i];
 
-            ctype val = Impl::init();
+            Impl impl;
+            ctype val = impl.init();
             auto subspec = SubTensorSpec::make_from_offset_elem(
                     sub_layout, offset);
             HostTensorND subt = const_cast<HostTensorND&>(src).sub(subspec);
             for (ctype i: megdnn::tensor_iter_valonly<ctype>(subt.as_megdnn())) {
-                val = Impl::reduce(val, i);
+                val = impl.reduce(val, i);
             }
-            *iter = val;
+            *iter = impl.finalize(val);
         }
     }
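
Note (illustrative aside, not part of the patch): the two hunks above turn the reference reducer from static Impl::init/reduce calls into a stateful instance, so the MEAN trait can count elements during reduce() and divide in finalize(); the other modes keep finalize() as an identity. A condensed standalone version of that pattern:

// Editor's sketch only: the stateful reference-reducer shape used
// above; the instance carries nr_elems so finalize() yields a mean.
#include <cstddef>
#include <cstdio>

struct MeanImpl {
    size_t nr_elems;
    float init() { nr_elems = 0; return 0.f; }
    float reduce(float accum, float v) { nr_elems ++; return accum + v; }
    float finalize(float result) { return result / float(nr_elems); }
};

int main() {
    const float data[] = {1.f, 2.f, 3.f, 4.f};
    MeanImpl impl;
    float val = impl.init();
    for (float v : data)
        val = impl.reduce(val, v);
    std::printf("%f\n", impl.finalize(val));  // prints 2.500000
}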
@@ -535,7 +578,7 @@ TEST(TestBasicArithReduction, DifferentNDim) {
 
     for (auto mode :
          {Reduce::Mode::PRODUCT, Reduce::Mode::MAX, Reduce::Mode::MIN,
-          Reduce::Mode::SUM, Reduce::Mode::SUM_SQR}) {
+          Reduce::Mode::SUM, Reduce::Mode::SUM_SQR, Reduce::Mode::MEAN}) {
         check_mode(mode);
     }
 }
@@ -606,7 +649,7 @@ TEST(TestBasicArithReduction, MultiType) {
     host_tshp->ptr<int>()[3] = 22;
     for (auto mode :
          {Reduce::Mode::PRODUCT, Reduce::Mode::MAX, Reduce::Mode::MIN,
-          Reduce::Mode::SUM, Reduce::Mode::SUM_SQR}) {
+          Reduce::Mode::SUM, Reduce::Mode::SUM_SQR, Reduce::Mode::MEAN}) {
         check_mode(mode);
     }
 }
@@ -682,18 +725,19 @@ TEST(TestBasicArithReduction, AutoCheck) {
 
     Param param;
-    auto make_graph = [&param](const Checker::SymInpArray& inputs)
+    auto make_graph = [&param](const Checker::SymInpArray& inputs, DType dtype)
             -> Checker::SymOutArray {
         auto inp = inputs[0];
         auto tshp = inputs[1].symshape();
-        inp = opr::TypeCvt::make(inp, dtype::Float16());
+        inp = opr::TypeCvt::make(inp, dtype);
         return {opr::Reduce::make(inp, param, tshp)};
     };
-    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
+    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp,
+                   DType dtype) {
         auto cn = inp[0]->storage().comp_node();
         TensorShape out_shape = inp[1]->shape();
         dest[0] = HostTensorND{cn, out_shape, dtype::Float32()};
-        HostTensorND tmp_inp{cn, inp[0]->shape(), dtype::Float16()};
+        HostTensorND tmp_inp{cn, inp[0]->shape(), dtype};
         HostTensorND new_inp{cn, inp[0]->shape(), dtype::Float32()};
         auto typecvt = megdnn_naive_handle()->create_operator<megdnn::TypeCvt>();
@@ -711,31 +755,38 @@ TEST(TestBasicArithReduction, AutoCheck) {
     dispatch_by_mode(ctype, Mode::MAX, in, out);     \
     dispatch_by_mode(ctype, Mode::SUM, in, out);     \
     dispatch_by_mode(ctype, Mode::PRODUCT, in, out); \
-    dispatch_by_mode(ctype, Mode::SUM_SQR, in, out);
+    dispatch_by_mode(ctype, Mode::SUM_SQR, in, out); \
+    dispatch_by_mode(ctype, Mode::MEAN, in, out);
 
-        mgb_assert(param.data_type == Param::DataType::FLOAT_O16xC32 ||
-                   param.data_type == Param::DataType::FLOAT_O32xC32);
+        mgb_assert(param.data_type == Param::DataType::FLOAT_O32xC32);
         dispatch_by_dtype(dtype::Float32, new_inp, dest[0]);
 #undef dispatch_by_mode
 #undef dispatch_by_dtype
     };
 
-    auto check = [&](Mode mode, Param::DataType data_type) {
+    auto check = [&](Mode mode, Param::DataType data_type, DType dtype) {
         param.mode = mode;
         param.data_type = data_type;
         Checker::RunOptions opts;
         opts.outputs_max_err = 1e-3;
         opts.numdiff_max_err = 5e-1;
-        Checker(make_graph, fwd)
-                .set_input_allow_grad(1, false)
-                .run({TensorShape{22, 21}, {22, 1}}, opts)
-                .run({TensorShape{22, 21}, {1, 1}}, opts)
-                .run({TensorShape{22, 21}, {22, 1}}, opts);
+        using namespace std::placeholders;
+        Checker checker(std::bind(make_graph, _1, dtype),
+                        std::bind(fwd, _1, _2, dtype));
+        if (dtype.category() == DTypeCategory::FLOAT) {
+            checker.set_input_allow_grad(1, false);
+        } else {
+            checker.disable_grad_check();
+        }
+        checker.run({TensorShape{22, 21}, {22, 1}}, opts)
+                .run({TensorShape{22, 21}, {1, 1}}, opts)
+                .run({TensorShape{22, 21}, {22, 1}}, opts);
     };
 
     for (auto mode :
-         {Mode::SUM, Mode::MAX, Mode::MIN, Mode::PRODUCT}) {
-        check(mode, Param::DataType::FLOAT_O32xC32);
+         {Mode::SUM, Mode::MAX, Mode::MIN, Mode::PRODUCT, Mode::MEAN}) {
+        check(mode, Param::DataType::FLOAT_O32xC32, dtype::Float16());
+        check(mode, Param::DataType::FLOAT_O32xC32, dtype::Int32());
     }
 }
@@ -747,6 +798,7 @@ OPR_TEST(SUM_SQR)
 OPR_TEST(PRODUCT)
 OPR_TEST(MAX)
 OPR_TEST(MIN)
+OPR_TEST(MEAN)
 
 TEST(TestBasicArithReduction, CompSeqRecordLevel2) {
     HostTensorGenerator<> gen;
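
Note (illustrative aside, not part of the patch): the AutoCheck rewrite threads the test dtype through std::bind, currying the extra dtype argument away so the bound callables still match the arity Checker expects; grad checking is kept only for float inputs, matching the InvalidGrad change above. A toy illustration of that currying, with stand-in names:

// Editor's sketch only: freezing a trailing argument with std::bind so
// a two-argument callback fits a one-argument callback slot.
#include <cstdio>
#include <functional>

static int make_graph(int inputs, int dtype_tag) { return inputs + dtype_tag; }

int main() {
    using namespace std::placeholders;
    auto bound = std::bind(make_graph, _1, 100);  // same shape as the patch
    std::printf("%d\n", bound(1));                // prints 101
}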