|
|
@@ -298,7 +298,8 @@ namespace { |
|
|
|
static_calc_x.copy_from(*host_x); |
|
|
|
opr::Reduce::perform( |
|
|
|
Mode::SUM, static_calc_y, static_calc_workspace, |
|
|
|
static_calc_x, oshp, static_calc_opr); |
|
|
|
static_calc_x, dtype::Float32(), oshp, |
|
|
|
static_calc_opr); |
|
|
|
host_y.ptr<float>()[0] ++; |
|
|
|
host_y.copy_from(static_calc_y); |
|
|
|
MGB_ASSERT_TENSOR_NEAR(expected, host_y, 1e-5); |
|
|
@@ -468,7 +469,8 @@ TEST(TestBasicArithReduction, NonContPerform) { |
|
|
|
for (auto &&tshp: |
|
|
|
TensorShapeArray{{5, 1}, {1, 5}, {1, 1}, {1}, {5, 5}}) { |
|
|
|
|
|
|
|
opr::Reduce::perform(mode, y, workspace, x, tshp, opr); |
|
|
|
opr::Reduce::perform(mode, y, workspace, x, dtype::Float32(), tshp, |
|
|
|
opr); |
|
|
|
ASSERT_TRUE(y.layout().is_contiguous()); |
|
|
|
ASSERT_EQ(tshp, y.shape()); |
|
|
|
size_t nr = tshp.total_nr_elems(); |
|
|
@@ -866,4 +868,36 @@ TEST(TestBasicArithReduction, StaticInferValue) { |
|
|
|
MGB_ASSERT_TENSOR_EQ(inferred, expected); |
|
|
|
} |
|
|
|
|
|
|
|
// Checks the dtype of the statically inferred value of a SUM Reduce for each
// combination of input dtype and Reduce::Param::DataType exercised below.
TEST(TestBasicArithReduction, StaticInferValueDType) {
    using ParamType = opr::Reduce::Param::DataType;
    DType fp32 = dtype::Float32();
    DType fp16 = dtype::Float16();

    // Builds TypeCvt(input) -> Reduce(SUM) with the target shape supplied as a
    // second graph input, runs static value inference on the result and
    // asserts the dtype of the inferred tensor.
    // NOTE: ASSERT_EQ only leaves this lambda on failure; the remaining cases
    // of the enclosing test still run.
    auto check_inferred_dtype = [](const DType& itype,
                                   const DType& expected_otype,
                                   ParamType param_dtype) {
        HostTensorGenerator<> generator;
        auto host_x = generator({2, 3, 4, 5});

        // One-element Int32 tensor holding the value 1, used as target shape.
        auto host_tshp = std::make_shared<HostTensorND>(host_x->comp_node(),
                                                        dtype::Int32());
        host_tshp->resize({1});
        *host_tshp->ptr<int>() = 1;

        auto graph = ComputingGraph::make();

        // The generator output is converted to the dtype under test.
        auto x_f32 = opr::Host2DeviceCopy::make(*graph, host_x);
        auto x = opr::TypeCvt::make(x_f32, itype);
        auto tshp = opr::Host2DeviceCopy::make(*graph, host_tshp);

        opr::Reduce::Param reduce_param{opr::Reduce::Mode::SUM,
                                        MEGDNN_MAX_NDIM, param_dtype};
        auto y = opr::Reduce::make(x, reduce_param, tshp);

        auto inferred = graph->static_infer_manager().infer_value(y.node());
        ASSERT_EQ(inferred.layout().dtype, expected_otype);
    };

    // (input dtype, expected output dtype, Reduce data-type mode)
    check_inferred_dtype(fp32, fp32, ParamType::DEFAULT);
    check_inferred_dtype(fp16, fp16, ParamType::DEFAULT);
    check_inferred_dtype(fp32, fp32, ParamType::FLOAT_O32xC32);
    check_inferred_dtype(fp16, fp32, ParamType::FLOAT_O32xC32);
    check_inferred_dtype(fp32, fp16, ParamType::FLOAT_O16xC32);
    check_inferred_dtype(fp16, fp16, ParamType::FLOAT_O16xC32);
}
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |