GitOrigin-RevId: 88b37123a8
tags/v1.6.0-rc1
@@ -13,7 +13,7 @@ import pytest
 from utils import opr_test

 import megengine.functional as F
-from megengine import tensor
+from megengine import jit, tensor


 def common_test_reduce(opr, ref_opr):
@@ -204,3 +204,28 @@ def test_topk(descending, sorted, inp1d, kth_only):
     if not sorted:
         values = np_sort(values)
     np.testing.assert_equal(values, np_sort(data)[..., :k])
+
+
+@pytest.mark.parametrize("is_trace", [True, False])
+def test_reduce_on_empty_tensor(is_trace):
+    dtypes = [np.float32, np.int32, np.bool]
+    inputs = [
+        (np.random.random((0,)), None),
+        (np.random.random((3, 0, 2)), 1),
+        (np.random.random((10, 10, 0, 10)), 0),
+    ]
+
+    def run_test(fn, ref_fn, input, dtype, axis=None, symbolic=False):
+        if is_trace:
+            fn = jit.trace(symbolic=symbolic)(fn)
+        for i in range(3):
+            out = fn(tensor(input, dtype=dtype), axis=axis).numpy()
+            out_ref = ref_fn(input.astype(dtype), axis=axis)
+            np.testing.assert_equal(out, out_ref)
+
+    for dtype in dtypes:
+        for inp, axis in inputs:
+            run_test(F.sum, np.sum, inp, dtype, axis, True)
+            run_test(F.sum, np.sum, inp, dtype, axis, False)
+            run_test(F.prod, np.prod, inp, dtype, axis, True)
+            run_test(F.prod, np.prod, inp, dtype, axis, False)
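Note: the expected values this test checks against are exactly NumPy's identity-element convention for reductions over empty tensors, illustrated standalone below (pure NumPy, nothing MegEngine-specific):

```python
import numpy as np

x = np.zeros((3, 0, 2), dtype=np.float32)
print(np.sum(x, axis=1))   # shape (3, 2), all 0: the additive identity
print(np.prod(x, axis=1))  # shape (3, 2), all 1: the multiplicative identity
# Reductions with no identity element fail instead:
# np.max(x, axis=1) -> ValueError: zero-size array to reduction operation ...
```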
@@ -84,6 +84,11 @@ public:
         auto&& dev_tensor = tensor.dev_tensor();
         var->m_comp_node = dev_tensor.comp_node();
         var->m_shape = dev_tensor.shape();
+        if (dev_tensor.empty()) {
+            auto layout = dev_tensor.layout();
+            layout.init_contiguous_stride();
+            dev_tensor.reset(dev_tensor.storage(), layout);
+        }
         var->m_dev_tensor = dev_tensor;
         var->m_mem_plan.reset_from_owner_var().chunk()
                 .mem_alloc_status.set_from_owner_var();
@@ -1364,7 +1364,7 @@ TEST(TestGraph, EmptyShapeCheck) {
     using Param = opr::CondTake::Param;
     auto x = opr::Host2DeviceCopy::make(*graph, host_x),
          y = opr::CondTake::make(x, x, {Param::Mode::GT})[0],
-         z = opr::reduce_sum(y, y.make_scalar(1));
+         z = opr::reduce_max(y, y.make_scalar(1));
     HostTensorND host_z;
     auto func = graph->compile({make_callback_copy(z, host_z)});
     func->execute();
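Note: the test switches from reduce_sum to reduce_max because SUM over an empty input is now well defined (the output is filled with the additive identity), so it no longer raises the error this test asserts; MAX has no identity and still throws. A sketch of the same situation through the Python API, assuming boolean-mask indexing lowers to CondTake as in this C++ test:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.array([-1.0, -2.0], dtype=np.float32))
y = x[x > 0]             # empty result, like CondTake with Mode::GT
print(F.sum(y).numpy())  # 0.0: SUM over an empty tensor is now allowed
# F.max(y)               # raises: empty input is not allowed for reduce mode: max
```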
@@ -1377,7 +1377,7 @@ TEST(TestGraph, EmptyShapeCheck) {
         func->execute();
     } catch (const MegBrainError& exc) {
         std::string msg{exc.what()};
-        ASSERT_TRUE(msg.find("empty output var") !=
+        ASSERT_TRUE(msg.find("empty input is not allowed") !=
                     std::string::npos)
                 << "bad message " << msg;
         throw;
@@ -2413,8 +2413,6 @@ TEST(TestMemReuse, ResetEmptyDevTensor) {
         y = opr::Reduce::make(x, {opr::Reduce::Mode::MAX, 0});
     HostTensorND host_y;
     auto func = g->compile({make_callback_copy(y, host_y)});
-    auto &&recv = x.node()->owner_graph()->var_receiver_in_current_comp_seq(x.node());
-    ASSERT_TRUE(!recv.is_empty_allowed());
     if (inp_shp.is_empty()) {
         ASSERT_ANY_THROW(func->execute().wait());
     } else {
@@ -1072,6 +1072,7 @@ class Reduce::KernScheduler {
             m_apply_side_effect;
     std::unique_ptr<megdnn::Elemwise> m_elemwise_trans_opr;
     std::unique_ptr<megdnn::TypeCvt> m_typecvt_opr;
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
     DeviceTensorND m_side_affect_wkspc;
 };
@@ -1338,6 +1339,47 @@ void Reduce::KernScheduler::execute(
     }
     mgb_assert(!m_kern_param.empty());
+
+    // empty input
+    if (input.shape_valid() && input.empty()) {
+        auto mode = m_kern_param[0].kparam.mode;
+        if (!m_fill_opr) {
+            m_fill_opr = intl::get_megdnn_handle(dest.comp_node())->
+                    create_operator<megdnn::Fill>();
+        }
+        std::string err_msg;
+        switch (mode) {
+            case Reduce::Mode::SUM:
+                if (!dest.empty()) {
+                    m_fill_opr->param() = 0;
+                    m_fill_opr->exec(dest.as_megdnn(), {});
+                }
+                break;
+            case Reduce::Mode::PRODUCT:
+                if (!dest.empty()) {
+                    m_fill_opr->param() = 1;
+                    m_fill_opr->exec(dest.as_megdnn(), {});
+                }
+                break;
+            case Reduce::Mode::MEAN:
+                err_msg = "mean"; break;
+            case Reduce::Mode::MIN:
+                err_msg = "min"; break;
+            case Reduce::Mode::MAX:
+                err_msg = "max"; break;
+            case Reduce::Mode::SUM_SQR:
+                err_msg = "sum_sqr"; break;
+            default:
+                mgb_throw(MegBrainError, "bad reduce mode");
+        }
+        if (!err_msg.empty()) {
+            mgb_throw(
+                    MegBrainError,
+                    "empty input is not allowed for reduce mode: %s",
+                    err_msg.c_str());
+        }
+        return;
+    }
+
     mgb_assert(input.layout().is_contiguous() &&
                input.raw_ptr() == m_kern_param[0].input.raw_ptr &&
                dest.raw_ptr() == m_kern_param.back().output.raw_ptr);
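Note: this hunk is the heart of the change. When the input turns out to be empty at runtime, the reduce kernels are skipped entirely; SUM and PRODUCT fill a non-empty output with their identity element through megdnn::Fill, while modes with no identity raise. The observable behavior, sketched with the public Python API that the new test exercises:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.zeros((3, 0, 2), dtype=np.float32))
print(F.sum(x, axis=1).numpy())   # zeros of shape (3, 2): Fill with identity 0
print(F.prod(x, axis=1).numpy())  # ones of shape (3, 2): Fill with identity 1
# F.max(x, axis=1)                # MegBrainError: empty input is not allowed ...
```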
@@ -1425,7 +1467,9 @@ Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param &param,
         mgb_throw(GraphError, "invalid param data_type: %d",
                   int(param.data_type));
     }
-    add_output(None)->dtype(out_dtype);
+    add_output(None)
+            ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
+            .dtype(out_dtype);
     cg::add_workspace_output(this);
     add_equivalence_component<PODHash<Param>>(&m_param);
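Note: ALLOW_EMPTY_SHAPE on the output var covers the complementary case where the output itself is empty because some non-reduced axis has extent zero; the `!dest.empty()` guards above then skip the Fill, since there is nothing to write. A sketch:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.zeros((2, 0, 3), dtype=np.float32))
print(F.sum(x, axis=2).shape)  # (2, 0): the output is itself empty, no Fill runs
```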
@@ -1703,6 +1747,13 @@ void Reduce::perform(
     ksched.execute(opr.get(), *input_contig, dest);
 }

+Reduce::NodeProp* Reduce::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 void Reduce::create_megdnn_opr() {
     set_megdnn_opr(intl::get_megdnn_handle(comp_node())->
             create_operator<megdnn::Reduce>());
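Note: DepType::VALUE_ALLOW_EMPTY declares that Reduce tolerates an empty input value; without it the static-graph executor rejects empty vars before scn_do_execute ever runs. The traced path of the new Python test hits exactly this code, roughly:

```python
import numpy as np
import megengine.functional as F
from megengine import jit, tensor

@jit.trace(symbolic=True)
def f(x):
    return F.sum(x, axis=0)

print(f(tensor(np.zeros((0,), dtype=np.float32))).numpy())  # 0.0
```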
@@ -335,6 +335,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
     void add_input_layout_constraint() override final;
     void scn_do_execute() override final;
     void init_output_static_infer_desc() override final;
+    NodeProp* do_make_node_prop() const override;
     void create_megdnn_opr() override;
     void record_execute_deps(ExecDependencyArray& deps) override;
@@ -900,4 +900,61 @@ TEST(TestBasicArithReduction, StaticInferValueDType) {
     run_test(F16, F16, ParamType::FLOAT_O16xC32);
 }

+TEST(TestBasicArithReduction, EmptyInput) {
+    using Param = opr::Reduce::Param;
+    using Mode = opr::Reduce::Mode;
+
+    auto check_allow_empty = [](const Param& param, const TensorShape& inpshp,
+                                double target_val) {
+        HostTensorGenerator<> gen;
+        auto graph = ComputingGraph::make();
+        auto host_x = gen(inpshp);
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Reduce::make(x, param, {});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute().wait();
+        if (!host_y.shape().is_empty()) {
+            size_t size = host_y.layout().total_nr_elems();
+#define cb(DType)                                               \
+    if (host_y.layout().dtype == DType()) {                     \
+        using ctype = typename DTypeTrait<DType>::ctype;        \
+        auto ptr = host_y.ptr<ctype>();                         \
+        ctype target = static_cast<ctype>(target_val);          \
+        for (size_t i = 0; i < size; ++i) {                     \
+            ASSERT_TRUE(ptr[i] == target);                      \
+        }                                                       \
+    }
+            MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+        } else {
+            ASSERT_TRUE(host_y.empty());
+        }
+    };
+
+    auto check_forbid_empty = [](const Param& param,
+                                 const TensorShape& inpshp) {
+        HostTensorGenerator<> gen;
+        auto graph = ComputingGraph::make();
+        auto host_x = gen(inpshp);
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Reduce::make(x, param, {});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        ASSERT_ANY_THROW(func->execute().wait());
+    };
+
+    check_allow_empty({Mode::SUM, 0, {}}, {0}, 0);
+    check_allow_empty({Mode::SUM, -1, {}}, {2, 0, 3}, 0);
+    check_allow_empty({Mode::SUM, 1, {}}, {2, 0, 3}, 0);
+    check_allow_empty({Mode::PRODUCT, 0, {}}, {0, 1, 2}, 1);
+    check_allow_empty({Mode::PRODUCT, 1, {}}, {0, 0, 0}, 1);
+    check_allow_empty({Mode::PRODUCT, 2, {}}, {0, 0, 0}, 1);
+    check_forbid_empty({Mode::MAX, 0, {}}, {0});
+    check_forbid_empty({Mode::MIN, -1, {}}, {0, 1, 2});
+    check_forbid_empty({Mode::MEAN, 0, {}}, {0, 0});
+    check_forbid_empty({Mode::SUM_SQR, 1, {}}, {2, 1, 0});
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}