GitOrigin-RevId: dfb401a945
release-1.6
@@ -18,6 +18,7 @@ import megengine.amp as amp
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
import megengine.jit as jit
from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
@@ -859,6 +860,35 @@ def test_condtake():
    np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
# @pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_condtake(is_symbolic=None):
    shapes = [
        (3, 3, 3),
        (0,),
        (3, 0, 3),
    ]

    def fn(mask, data):
        return F.cond_take(mask, data)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)
    for shp in shapes:
        x_np = np.random.randn(*shp).astype("float32")
        mask_np = x_np > 0
        x = tensor(x_np)
        mask = tensor(mask_np)
        ref_out = x_np[mask_np]
        ref_idx = mask_np.flatten().nonzero()[0]
        for i in range(3):
            out, idx = fn(mask, x)
            np.testing.assert_equal(out.numpy(), ref_out)
            np.testing.assert_equal(idx.numpy(), ref_idx)
            if is_symbolic is None:
                break
def test_condtake_is_same():
    op1 = builtin.CondTake()
    op2 = builtin.CondTake()
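For reference, a minimal eager-mode sketch of the behavior this hunk tests (not part of the patch; it assumes a MegEngine build that contains this change): cond_take on an empty input now returns empty value/index tensors instead of failing inside the kernel.

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # data and mask share the empty shape (3, 0, 3)
    data = tensor(np.zeros((3, 0, 3), dtype="float32"))
    mask = tensor(np.zeros((3, 0, 3), dtype="bool"))
    val, idx = F.cond_take(mask, data)   # taken values, flattened indices
    assert val.numpy().shape == (0,)     # nothing taken
    assert idx.numpy().shape == (0,)     # index output is empty as well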
@@ -45,25 +45,30 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
    auto&& inp = inputs[0];
    auto&& msk = inputs[1];
    SmallVector<TensorPtr> out;
    mgb_assert(inp->layout().eq_shape(msk->layout()),
               "input shape does not match mask shape");
    mgb_assert(msk->get_value().dtype().enumv() == DTypeEnum::Bool,
               "mask dtype must be bool");
    DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
    dnn_op.op->param().val = 1;
    TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
                          dtype::Byte());
    auto dnn_workspace = dnn_op.create_workspace(m_layout);
    MegDNNDynOutMallocImpl<2> policy{inp->comp_node()};
    dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
                    msk->dev_tensor().as_megdnn(),
                    dnn_workspace,
                    &policy);
    SmallVector<TensorPtr> out;
    if (inp->layout().is_empty()) {
        // empty tensor
        policy.alloc_output(0, inp->layout().dtype, {0}, nullptr);
        policy.alloc_output(1, dtype::Int32(), {0}, nullptr);
    } else {
        DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
        dnn_op.op->param().val = 1;
        TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
                              dtype::Byte());
        auto dnn_workspace = dnn_op.create_workspace(m_layout);
        dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
                        msk->dev_tensor().as_megdnn(),
                        dnn_workspace,
                        &policy);
    }
    out.push_back(policy.at(0));
    out.push_back(policy.at(1));
    return out;
@@ -264,6 +264,15 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
    }
}
CondTake::NodeProp* CondTake::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0),
                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
    ret->add_dep_type_existing_var(input(1),
                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondTake) {
    mgb_assert(out_grad.size() == 3 && !out_grad[2]);
@@ -305,11 +314,21 @@ void CondTake::add_input_layout_constraint() {
}
void CondTake::scn_do_execute() {
    auto&& data = input(0)->dev_tensor();
    auto&& mask = input(1)->dev_tensor();
    intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                       input(1)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output().back()),
                       &dyn_malloc);
    if (data.layout().is_empty()) {
        mgb_assert(data.layout().eq_shape(mask.layout()),
                   "CondTake shape differs: data=%s mask=%s",
                   data.layout().TensorShape::to_string().c_str(),
                   mask.layout().TensorShape::to_string().c_str());
        dyn_malloc.alloc_output(0, data.layout().dtype, {0}, nullptr);
        dyn_malloc.alloc_output(1, dtype::Int32(), {0}, nullptr);
    } else {
        megdnn_opr()->exec(data.as_megdnn(), mask.as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output().back()),
                           &dyn_malloc);
    }
}
/* ================= TopK ================= */
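The same empty-input handling is exercised through the graph executor (CondTake::scn_do_execute) when the function is traced symbolically. A hedged sketch mirroring the Python test above, again assuming a build that includes this patch:

    import numpy as np
    import megengine.functional as F
    import megengine.jit as jit
    from megengine import tensor

    def fn(mask, data):
        return F.cond_take(mask, data)

    # symbolic trace compiles a graph containing the CondTake operator
    fn = jit.trace(symbolic=True)(fn)
    val, idx = fn(tensor(np.zeros((0,), dtype="bool")),
                  tensor(np.zeros((0,), dtype="float32")))
    assert val.numpy().shape == (0,) and idx.numpy().shape == (0,)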
@@ -151,6 +151,7 @@ MGB_DEFINE_OPR_CLASS(CondTake, intl::CondTakeBase) // {
    void init_output_static_infer_desc() override;
    void scn_do_execute() override;
    void add_input_layout_constraint() override;
    NodeProp* do_make_node_prop() const override;

    public:
        CondTake(VarNode *data, VarNode *mask,
@@ -256,20 +256,25 @@ TEST(TestOprMisc, CondTake) {
    run(mki({100}));
}
TEST(TestOprMisc, CondTakeEmptyOut) {
TEST(TestOprMisc, CondTakeEmptyIO) {
    using Param = opr::CondTake::Param;
    HostTensorGenerator<> gen;
    auto host_x = gen({1});
    host_x->ptr<float>()[0] = 1;
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto out = opr::CondTake::make(x, x, {Param::Mode::LT});
    HostTensorND host_out0, host_out1;
    auto func = graph->compile({make_callback_copy(out[0], host_out0),
                                make_callback_copy(out[1], host_out1)});
    func->execute();
    ASSERT_EQ(TensorShape{0}, host_out0.shape());
    ASSERT_EQ(TensorShape{0}, host_out1.shape());
    auto check = [&](const TensorShape& shp) {
        auto host_x = gen(shp);
        auto graph = ComputingGraph::make();
        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
        auto y = x + 1;
        auto out = opr::CondTake::make(x, y, {Param::Mode::EQ});
        HostTensorND host_out0, host_out1;
        auto func = graph->compile({make_callback_copy(out[0], host_out0),
                                    make_callback_copy(out[1], host_out1)});
        func->execute();
        ASSERT_EQ(TensorShape{0}, host_out0.shape());
        ASSERT_EQ(TensorShape{0}, host_out1.shape());
    };
    check({1});
    check({0});
    check({1, 0});
}
TEST(TestOprMisc, TopKValueOnly) {