GitOrigin-RevId: 88b37123a8
tags/v1.6.0-rc1
@@ -13,7 +13,7 @@ import pytest
 from utils import opr_test
 
 import megengine.functional as F
-from megengine import tensor
+from megengine import jit, tensor
 
 
 def common_test_reduce(opr, ref_opr):
@@ -204,3 +204,28 @@ def test_topk(descending, sorted, inp1d, kth_only):
     if not sorted:
         values = np_sort(values)
     np.testing.assert_equal(values, np_sort(data)[..., :k])
+
+
+@pytest.mark.parametrize("is_trace", [True, False])
+def test_reduce_on_empty_tensor(is_trace):
+    dtypes = [np.float32, np.int32, np.bool_]
+    inputs = [
+        (np.random.random((0,)), None),
+        (np.random.random((3, 0, 2)), 1),
+        (np.random.random((10, 10, 0, 10)), 0),
+    ]
+
+    def run_test(fn, ref_fn, input, dtype, axis=None, symbolic=False):
+        if is_trace:
+            fn = jit.trace(symbolic=symbolic)(fn)
+        for i in range(3):
+            out = fn(tensor(input, dtype=dtype), axis=axis).numpy()
+            out_ref = ref_fn(input.astype(dtype), axis=axis)
+            np.testing.assert_equal(out, out_ref)
+
+    for dtype in dtypes:
+        for inp, axis in inputs:
+            run_test(F.sum, np.sum, inp, dtype, axis, True)
+            run_test(F.sum, np.sum, inp, dtype, axis, False)
+            run_test(F.prod, np.prod, inp, dtype, axis, True)
+            run_test(F.prod, np.prod, inp, dtype, axis, False)
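
Note: the reference semantics the new test checks come straight from NumPy,
where reducing over an empty axis yields the operation's identity element. A
minimal NumPy-only illustration (nothing MegEngine-specific assumed):

    import numpy as np

    x = np.random.random((3, 0, 2))
    print(np.sum(x, axis=1))   # zeros of shape (3, 2): 0 is the additive identity
    print(np.prod(x, axis=1))  # ones of shape (3, 2): 1 is the multiplicative identity
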
@@ -84,6 +84,11 @@ public:
         auto&& dev_tensor = tensor.dev_tensor();
         var->m_comp_node = dev_tensor.comp_node();
         var->m_shape = dev_tensor.shape();
+        if (dev_tensor.empty()) {
+            auto layout = dev_tensor.layout();
+            layout.init_contiguous_stride();
+            dev_tensor.reset(dev_tensor.storage(), layout);
+        }
         var->m_dev_tensor = dev_tensor;
         var->m_mem_plan.reset_from_owner_var().chunk()
                 .mem_alloc_status.set_from_owner_var();
@@ -1364,7 +1364,7 @@ TEST(TestGraph, EmptyShapeCheck) {
     using Param = opr::CondTake::Param;
     auto x = opr::Host2DeviceCopy::make(*graph, host_x),
          y = opr::CondTake::make(x, x, {Param::Mode::GT})[0],
-         z = opr::reduce_sum(y, y.make_scalar(1));
+         z = opr::reduce_max(y, y.make_scalar(1));
     HostTensorND host_z;
     auto func = graph->compile({make_callback_copy(z, host_z)});
     func->execute();
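
The switch from reduce_sum to reduce_max here is deliberate: after this change
SUM over an empty var succeeds, so only a mode without an identity element
(MAX) still exercises the empty-shape error path. A sketch of the
corresponding Python-level behavior, assuming the functional wrappers F.sum
and F.max route to this Reduce operator:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    empty = tensor(np.zeros((0,), dtype=np.float32))
    print(F.sum(empty).numpy())  # 0.0 -- succeeds after this change
    # F.max(empty)               # still raises, keeping the error path testable
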
@@ -1377,7 +1377,7 @@ TEST(TestGraph, EmptyShapeCheck) {
             func->execute();
         } catch (const MegBrainError& exc) {
             std::string msg{exc.what()};
-            ASSERT_TRUE(msg.find("empty output var") !=
+            ASSERT_TRUE(msg.find("empty input is not allowed") !=
                         std::string::npos)
                     << "bad message " << msg;
             throw;
@@ -2413,8 +2413,6 @@ TEST(TestMemReuse, ResetEmptyDevTensor) {
         y = opr::Reduce::make(x, {opr::Reduce::Mode::MAX, 0});
     HostTensorND host_y;
     auto func = g->compile({make_callback_copy(y, host_y)});
-    auto &&recv = x.node()->owner_graph()->var_receiver_in_current_comp_seq(x.node());
-    ASSERT_TRUE(!recv.is_empty_allowed());
     if (inp_shp.is_empty()) {
         ASSERT_ANY_THROW(func->execute().wait());
     } else {
@@ -1072,6 +1072,7 @@ class Reduce::KernScheduler {
             m_apply_side_effect;
     std::unique_ptr<megdnn::Elemwise> m_elemwise_trans_opr;
     std::unique_ptr<megdnn::TypeCvt> m_typecvt_opr;
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
     DeviceTensorND m_side_affect_wkspc;
 };
@@ -1338,6 +1339,47 @@ void Reduce::KernScheduler::execute(
     }
     mgb_assert(!m_kern_param.empty());
+
+    // empty input
+    if (input.shape_valid() && input.empty()) {
+        auto mode = m_kern_param[0].kparam.mode;
+        if (!m_fill_opr) {
+            m_fill_opr = intl::get_megdnn_handle(dest.comp_node())->
+                    create_operator<megdnn::Fill>();
+        }
+        std::string err_msg;
+        switch (mode) {
+            case Reduce::Mode::SUM:
+                if (!dest.empty()) {
+                    m_fill_opr->param() = 0;
+                    m_fill_opr->exec(dest.as_megdnn(), {});
+                }
+                break;
+            case Reduce::Mode::PRODUCT:
+                if (!dest.empty()) {
+                    m_fill_opr->param() = 1;
+                    m_fill_opr->exec(dest.as_megdnn(), {});
+                }
+                break;
+            case Reduce::Mode::MEAN:
+                err_msg = "mean"; break;
+            case Reduce::Mode::MIN:
+                err_msg = "min"; break;
+            case Reduce::Mode::MAX:
+                err_msg = "max"; break;
+            case Reduce::Mode::SUM_SQR:
+                err_msg = "sum_sqr"; break;
+            default:
+                mgb_throw(MegBrainError, "bad reduce mode");
+        }
+        if (!err_msg.empty()) {
+            mgb_throw(
+                    MegBrainError,
+                    "empty input is not allowed for reduce mode: %s",
+                    err_msg.c_str());
+        }
+        return;
+    }
 
     mgb_assert(input.layout().is_contiguous() &&
                input.raw_ptr() == m_kern_param[0].input.raw_ptr &&
                dest.raw_ptr() == m_kern_param.back().output.raw_ptr);
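
In user-facing terms: SUM and PRODUCT of an empty input fill the (non-empty)
output with their identity elements, 0 and 1, while the remaining modes raise
(MEAN would divide by zero elements; MIN and MAX have no neutral value
representable in every dtype). A hypothetical Python session showing both
outcomes:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.random.random((3, 0, 2)), dtype=np.float32)
    print(F.sum(x, axis=1).numpy())   # (3, 2) array of zeros
    print(F.prod(x, axis=1).numpy())  # (3, 2) array of ones
    # F.max(x, axis=1)  # raises: "empty input is not allowed for reduce mode: max"
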
@@ -1425,7 +1467,9 @@ Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param &param,
             mgb_throw(GraphError, "invalid param data_type: %d",
                       int(param.data_type));
     }
-    add_output(None)->dtype(out_dtype);
+    add_output(None)
+            ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
+            .dtype(out_dtype);
     cg::add_workspace_output(this);
     add_equivalence_component<PODHash<Param>>(&m_param);
@@ -1703,6 +1747,13 @@ void Reduce::perform(
     ksched.execute(opr.get(), *input_contig, dest);
 }
 
+Reduce::NodeProp* Reduce::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 void Reduce::create_megdnn_opr() {
     set_megdnn_opr(intl::get_megdnn_handle(comp_node())->
             create_operator<megdnn::Reduce>());
@@ -335,6 +335,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
         void add_input_layout_constraint() override final;
         void scn_do_execute() override final;
         void init_output_static_infer_desc() override final;
+        NodeProp* do_make_node_prop() const override;
         void create_megdnn_opr() override;
         void record_execute_deps(ExecDependencyArray& deps) override;
@@ -900,4 +900,61 @@ TEST(TestBasicArithReduction, StaticInferValueDType) {
     run_test(F16, F16, ParamType::FLOAT_O16xC32);
 }
 
+TEST(TestBasicArithReduction, EmptyInput) {
+    using Param = opr::Reduce::Param;
+    using Mode = opr::Reduce::Mode;
+
+    auto check_allow_empty = [](const Param& param, const TensorShape& inpshp,
+                                double target_val) {
+        HostTensorGenerator<> gen;
+        auto graph = ComputingGraph::make();
+        auto host_x = gen(inpshp);
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Reduce::make(x, param, {});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute().wait();
+        if (!host_y.shape().is_empty()) {
+            size_t size = host_y.layout().total_nr_elems();
+#define cb(DType)                                            \
+    if (host_y.layout().dtype == DType()) {                  \
+        using ctype = typename DTypeTrait<DType>::ctype;     \
+        auto ptr = host_y.ptr<ctype>();                      \
+        ctype target = static_cast<ctype>(target_val);       \
+        for (size_t i = 0; i < size; ++i) {                  \
+            ASSERT_TRUE(ptr[i] == target);                   \
+        }                                                    \
+    }
+            MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+        } else {
+            ASSERT_TRUE(host_y.empty());
+        }
+    };
+
+    auto check_forbid_empty = [](const Param& param,
+                                 const TensorShape& inpshp) {
+        HostTensorGenerator<> gen;
+        auto graph = ComputingGraph::make();
+        auto host_x = gen(inpshp);
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Reduce::make(x, param, {});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        ASSERT_ANY_THROW(func->execute().wait());
+    };
+
+    check_allow_empty({Mode::SUM, 0, {}}, {0}, 0);
+    check_allow_empty({Mode::SUM, -1, {}}, {2, 0, 3}, 0);
+    check_allow_empty({Mode::SUM, 1, {}}, {2, 0, 3}, 0);
+    check_allow_empty({Mode::PRODUCT, 0, {}}, {0, 1, 2}, 1);
+    check_allow_empty({Mode::PRODUCT, 1, {}}, {0, 0, 0}, 1);
+    check_allow_empty({Mode::PRODUCT, 2, {}}, {0, 0, 0}, 1);
+
+    check_forbid_empty({Mode::MAX, 0, {}}, {0});
+    check_forbid_empty({Mode::MIN, -1, {}}, {0, 1, 2});
+    check_forbid_empty({Mode::MEAN, 0, {}}, {0, 0});
+    check_forbid_empty({Mode::SUM_SQR, 1, {}}, {2, 1, 0});
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}