@@ -15,6 +15,7 @@ from utils import make_tensor | |||||
import megengine | import megengine | ||||
import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.jit as jit | |||||
from megengine.core._imperative_rt.core2 import apply | from megengine.core._imperative_rt.core2 import apply | ||||
from megengine.core._trace_option import use_symbolic_shape | from megengine.core._trace_option import use_symbolic_shape | ||||
from megengine.core.ops import builtin | from megengine.core.ops import builtin | ||||
@@ -584,3 +585,26 @@ def test_advance_indexing_with_bool(test_varnode): | |||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
a[:, b, 0:2, [True, False]], aa[:, bb, 0:2, [True, False]].numpy() | a[:, b, 0:2, [True, False]], aa[:, bb, 0:2, [True, False]].numpy() | ||||
) | ) | ||||
@pytest.mark.parametrize("symbolic", [True, False, None]) | |||||
def test_subtensor_on_empty_tensor(symbolic): | |||||
np_x = np.array([], dtype=np.float32).reshape(10, 0, 10) | |||||
mge_x = megengine.tensor(np_x) | |||||
def run_test(fn): | |||||
out_ref = fn(np_x) | |||||
if symbolic is not None: | |||||
fn = jit.trace(symbolic=symbolic)(fn) | |||||
for i in range(3): | |||||
out = fn(mge_x) | |||||
np.testing.assert_equal(out.numpy(), out_ref) | |||||
run_test(lambda x: x[0:1, :, :]) | |||||
run_test(lambda x: x[1:100:2, :, :]) | |||||
run_test(lambda x: x[-10:5:2, :, :]) | |||||
run_test(lambda x: x[5:1:-1, :, :]) | |||||
run_test(lambda x: x[3, 10:1:1, 5]) | |||||
run_test(lambda x: x[3, 10:1:1, 5:-1]) | |||||
run_test(lambda x: x[:100, :100, :100]) | |||||
run_test(lambda x: x[100:200, 300:400, 500:600]) |
@@ -133,27 +133,42 @@ SubTensorSpec Slice::apply(TensorLayout layout, int axis) const { | |||||
return "None"; | return "None"; | ||||
return std::to_string(v.val()); | return std::to_string(v.val()); | ||||
}; | }; | ||||
auto mod_size = [size_ax](ptrdiff_t v) { | |||||
auto mod_size = [size_ax](ptrdiff_t v)->ptrdiff_t { | |||||
if (size_ax == 0) return 0; | |||||
return v < 0 ? v + size_ax : v; | return v < 0 ? v + size_ax : v; | ||||
}; | }; | ||||
MGB_MARK_USED_VAR(tostr); | MGB_MARK_USED_VAR(tostr); | ||||
#define CHECK(cond) \ | |||||
mgb_assert(cond, \ | |||||
"index out of bound: layout=%s; request begin=%s end=%s step=%s " \ | |||||
"axis=%d", \ | |||||
layout.to_string().c_str(), tostr(m_begin).c_str(), \ | |||||
tostr(m_end).c_str(), tostr(m_step).c_str(), axis) | |||||
#define CHECK(cond) \ | |||||
if (m_is_scalar_idx) { \ | |||||
mgb_assert(cond, \ | |||||
"index out of bound: layout=%s; request index=%s, axis=%d", \ | |||||
layout.to_string().c_str(), tostr(m_begin).c_str(), axis); \ | |||||
} else { \ | |||||
mgb_assert(cond, \ | |||||
"index out of bound: layout=%s; request begin=%s end=%s step=%s " \ | |||||
"axis=%d", \ | |||||
layout.to_string().c_str(), tostr(m_begin).c_str(), \ | |||||
tostr(m_end).c_str(), tostr(m_step).c_str(), axis); \ | |||||
} | |||||
if (step > 0) { | if (step > 0) { | ||||
begin = mod_size(m_begin.val_with_default(0)); | begin = mod_size(m_begin.val_with_default(0)); | ||||
end = mod_size(m_end.val_with_default(size_ax)); | end = mod_size(m_end.val_with_default(size_ax)); | ||||
CHECK(begin >= 0 && end >= begin && end <= size_ax); | |||||
if (!m_is_scalar_idx) { | |||||
end = std::min(end, size_ax); | |||||
begin = std::min(begin, end); | |||||
} | |||||
CHECK(begin >= 0 && end >= begin && end <= size_ax) | |||||
} else { | } else { | ||||
begin = mod_size(m_begin.val_with_default(size_ax - 1)); | begin = mod_size(m_begin.val_with_default(size_ax - 1)); | ||||
end = m_end.valid() ? mod_size(m_end.val()) : -1; | end = m_end.valid() ? mod_size(m_end.val()) : -1; | ||||
if (!m_is_scalar_idx) { | |||||
begin = std::min(begin, std::max<ptrdiff_t>(size_ax-1, 0)); | |||||
end = std::min(end, begin); | |||||
} | |||||
CHECK(step < 0 && begin >= 0 && end <= begin && begin < size_ax && | CHECK(step < 0 && begin >= 0 && end <= begin && begin < size_ax && | ||||
end >= -1); | |||||
end >= -1) | |||||
} | } | ||||
auto step_abs = std::abs(step); | auto step_abs = std::abs(step); | ||||
layout.shape[axis] = (std::abs(end - begin) + step_abs - 1) / step_abs; | layout.shape[axis] = (std::abs(end - begin) + step_abs - 1) / step_abs; | ||||
@@ -83,16 +83,20 @@ class SubTensorSpec { | |||||
/*! | /*! | ||||
* \brief slice along some axis; index as in Python, with negative indices | * \brief slice along some axis; index as in Python, with negative indices | ||||
* supported | |||||
* supported. Scalar index can also be represented as a Slice, where | |||||
* m_begin = idx, m_end = idx+1 and m_step = 1. The flag m_is_scalar_idx | |||||
* indicates whether the Slice comes from a scalar index. | |||||
*/ | */ | ||||
class Slice { | class Slice { | ||||
Maybe<ptrdiff_t> m_begin, m_end, m_step; | Maybe<ptrdiff_t> m_begin, m_end, m_step; | ||||
bool m_is_scalar_idx; | |||||
public: | public: | ||||
Slice(Maybe<ptrdiff_t> begin = None, | Slice(Maybe<ptrdiff_t> begin = None, | ||||
Maybe<ptrdiff_t> end = None, | Maybe<ptrdiff_t> end = None, | ||||
Maybe<ptrdiff_t> step = None): | |||||
m_begin{begin}, m_end{end}, m_step{step} | |||||
Maybe<ptrdiff_t> step = None, | |||||
bool is_scalar_idx = false): | |||||
m_begin{begin}, m_end{end}, m_step{step}, m_is_scalar_idx{is_scalar_idx} | |||||
{ } | { } | ||||
/*! | /*! | ||||
@@ -178,7 +178,9 @@ SubTensorSpec FancyIndexingHelper::do_make_sub_spec( | |||||
i.axis.get_raw(), axis); | i.axis.get_raw(), axis); | ||||
prev_axis = axis; | prev_axis = axis; | ||||
Maybe<ptrdiff_t> begin, end, step; | Maybe<ptrdiff_t> begin, end, step; | ||||
bool is_scalar_idx = false; | |||||
if (i.idx.node()) { | if (i.idx.node()) { | ||||
is_scalar_idx = true; | |||||
if (!m_require_scalar_index) { | if (!m_require_scalar_index) { | ||||
continue; | continue; | ||||
} | } | ||||
@@ -195,7 +197,7 @@ SubTensorSpec FancyIndexingHelper::do_make_sub_spec( | |||||
step = next_iv(); | step = next_iv(); | ||||
} | } | ||||
spec.merge_with(Slice(begin, end, step).apply(spec.layout(), axis)); | |||||
spec.merge_with(Slice(begin, end, step, is_scalar_idx).apply(spec.layout(), axis)); | |||||
} | } | ||||
mgb_assert(iv_iter == m_value_infer_result.end()); | mgb_assert(iv_iter == m_value_infer_result.end()); | ||||
@@ -660,7 +660,19 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) { | |||||
/* f{{{ ======================= Subtensor ======================= */ | /* f{{{ ======================= Subtensor ======================= */ | ||||
MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); | |||||
Subtensor::Subtensor(VarNode *inp, const IndexDesc &desc, | |||||
const OperatorNodeConfig &config): | |||||
Super({inp->owner_graph(), config, "subtensor", {inp}}, | |||||
inp, nullptr, desc, true) { | |||||
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
} | |||||
SymbolVar Subtensor::make(SymbolVar inp, const IndexDesc &desc, | |||||
const OperatorNodeConfig &config) { | |||||
return inp.insert_single_output_opr<Subtensor>(inp.node(), desc, config); | |||||
} | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Subtensor); | |||||
#if MGB_ENABLE_GRAD | #if MGB_ENABLE_GRAD | ||||
MGB_IMPL_OPR_GRAD(Subtensor) { | MGB_IMPL_OPR_GRAD(Subtensor) { | ||||
@@ -722,6 +734,13 @@ void Subtensor::init_rt_force_dynamic_mem_alloc_imply_chain() { | |||||
out->add_rt_force_dynamic_mem_alloc_imply_chain(inp); | out->add_rt_force_dynamic_mem_alloc_imply_chain(inp); | ||||
} | } | ||||
Subtensor::NodeProp* Subtensor::do_make_node_prop() const { | |||||
auto ret = Super::do_make_node_prop(); | |||||
ret->add_dep_type_existing_var(input(0), | |||||
NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||||
return ret; | |||||
} | |||||
// f}}} | // f}}} | ||||
/* f{{{ ================== ModifySubtensorImplHelper ================== */ | /* f{{{ ================== ModifySubtensorImplHelper ================== */ | ||||
@@ -358,6 +358,7 @@ MGB_DEFINE_OPR_CLASS(Subtensor, | |||||
void scn_do_execute() override; | void scn_do_execute() override; | ||||
void mem_plan_fwd_in2out_readonly() override; | void mem_plan_fwd_in2out_readonly() override; | ||||
void init_rt_force_dynamic_mem_alloc_imply_chain() override; | void init_rt_force_dynamic_mem_alloc_imply_chain() override; | ||||
NodeProp* do_make_node_prop() const override; | |||||
public: | public: | ||||
Subtensor(VarNode *inp, const IndexDesc &desc, | Subtensor(VarNode *inp, const IndexDesc &desc, | ||||
@@ -894,6 +894,47 @@ TEST(TestTensorManip, SubtensorIdxChange) { | |||||
run(false); | run(false); | ||||
} | } | ||||
TEST(TestTensorManip, SubtensorEmptyIO) { | |||||
using AIdx = opr::Subtensor::AxisIndexer; | |||||
using IndexDesc = std::vector<AIdx>; | |||||
using IndexDescCreater = thin_function<IndexDesc(SymbolVar)>; | |||||
HostTensorGenerator<> gen; | |||||
auto run = [&](const TensorShape& inp_shp, const TensorShape& out_shp, const IndexDescCreater& c) { | |||||
auto host_x = gen(inp_shp); | |||||
auto graph = ComputingGraph::make(); | |||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||||
auto y = opr::Subtensor::make(x, c(x)); | |||||
HostTensorND host_y; | |||||
auto func = graph->compile({make_callback_copy(y, host_y)}); | |||||
func->execute(); | |||||
ASSERT_EQ(host_y.shape(), out_shp); | |||||
ASSERT_TRUE(host_y.empty()); | |||||
}; | |||||
// x.shape = {0}, x[:0] | |||||
run({0}, {0}, [&](SymbolVar x)->IndexDesc { | |||||
return {AIdx::make_interval(0, None, x.make_scalar(0), None)}; | |||||
}); | |||||
// x.shape = {100, 0}, x[0:-10:2] | |||||
run({100, 0}, {45, 0}, [&](SymbolVar x)->IndexDesc { | |||||
return {AIdx::make_interval(0, x.make_scalar(0), x.make_scalar(-10), x.make_scalar(2))}; | |||||
}); | |||||
// x.shape = {100, 0}, x[10:-10:2, 0:0] | |||||
run({100, 0}, {40, 0}, [&](SymbolVar x)->IndexDesc { | |||||
return {AIdx::make_interval(0, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2)), | |||||
AIdx::make_interval(1, x.make_scalar(0), x.make_scalar(0), None)}; | |||||
}); | |||||
// x.shape = {10, 0, 10}, x[5, 10:-10:-2] | |||||
run({10, 0, 10}, {0, 10}, [&](SymbolVar x)->IndexDesc { | |||||
return {AIdx::make_index(0, x.make_scalar(5)), | |||||
AIdx::make_interval(1, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2))}; | |||||
}); | |||||
// x.shape = {10}, x[100:] | |||||
run({10}, {0}, [&](SymbolVar x)->IndexDesc { | |||||
return {AIdx::make_interval(0, x.make_scalar(100), None, None)}; | |||||
}); | |||||
} | |||||
namespace { | namespace { | ||||
void test_subtensor_fwdonly(bool dyn_inp, bool dyn_idx) { | void test_subtensor_fwdonly(bool dyn_inp, bool dyn_idx) { | ||||