Browse Source

fix(mge): ensure contiguous when passing tensor from graph rt to imperative rt

GitOrigin-RevId: e4d944343d
release-1.1
Megvii Engine Team 4 years ago
parent
commit
e8bf5bc002
3 changed files with 21 additions and 0 deletions
  1. +13
    -0
      imperative/python/test/unit/test_tracing.py
  2. +6
    -0
      imperative/src/impl/opr_utility.cpp
  3. +2
    -0
      imperative/src/include/megbrain/imperative/opr_utility.h

+ 13
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -406,3 +406,16 @@ def test_clip():


for i in range(3): for i in range(3):
f(x, tensor([0]), tensor([1])) f(x, tensor([0]), tensor([1]))


# test returning noncontiguous tensor from trace
def test_slice():
@trace
def f(x):
return x[:, 1::2]

x = F.arange(8).reshape(2, 4)
f(x)
y = f(x)
np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2])
y + y

+ 6
- 0
imperative/src/impl/opr_utility.cpp View File

@@ -156,6 +156,12 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const {
return prop; return prop;
} }


void OutputCallback::add_input_layout_constraint() {
if (m_param.require_contiguous) {
input(0)->add_layout_constraint_contiguous();
}
}

void OutputCallback::scn_do_execute() { void OutputCallback::scn_do_execute() {
if (m_use_host_value) { if (m_use_host_value) {
m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0))); m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0)));


+ 2
- 0
imperative/src/include/megbrain/imperative/opr_utility.h View File

@@ -62,6 +62,7 @@ public:
callback_t callback; callback_t callback;
bool borrow = false; // do not obtain shared ownership on DeviceTensorND bool borrow = false; // do not obtain shared ownership on DeviceTensorND
bool prefer_host_value = false; // use host value when possible bool prefer_host_value = false; // use host value when possible
bool require_contiguous = true;
}; };
OutputCallback(Param param, OutputCallback(Param param,
const VarNodeArray& inputs, const VarNodeArray& inputs,
@@ -80,6 +81,7 @@ protected:
void scn_do_execute() override; void scn_do_execute() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override; NodeProp* do_make_node_prop() const override;
void add_input_layout_constraint() override;
private: private:
Param m_param; Param m_param;
mutable bool m_use_host_value; mutable bool m_use_host_value;


Loading…
Cancel
Save