GitOrigin-RevId: e4d944343d
release-1.1
@@ -406,3 +406,16 @@ def test_clip(): | |||||
for i in range(3): | for i in range(3): | ||||
f(x, tensor([0]), tensor([1])) | f(x, tensor([0]), tensor([1])) | ||||
# test returning noncontiguous tensor from trace | |||||
def test_slice(): | |||||
@trace | |||||
def f(x): | |||||
return x[:, 1::2] | |||||
x = F.arange(8).reshape(2, 4) | |||||
f(x) | |||||
y = f(x) | |||||
np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) | |||||
y + y |
@@ -156,6 +156,12 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const { | |||||
return prop; | return prop; | ||||
} | } | ||||
void OutputCallback::add_input_layout_constraint() { | |||||
if (m_param.require_contiguous) { | |||||
input(0)->add_layout_constraint_contiguous(); | |||||
} | |||||
} | |||||
void OutputCallback::scn_do_execute() { | void OutputCallback::scn_do_execute() { | ||||
if (m_use_host_value) { | if (m_use_host_value) { | ||||
m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0))); | m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0))); | ||||
@@ -62,6 +62,7 @@ public: | |||||
callback_t callback; | callback_t callback; | ||||
bool borrow = false; // do not obtain shared ownership on DeviceTensorND | bool borrow = false; // do not obtain shared ownership on DeviceTensorND | ||||
bool prefer_host_value = false; // use host value when possible | bool prefer_host_value = false; // use host value when possible | ||||
bool require_contiguous = true; | |||||
}; | }; | ||||
OutputCallback(Param param, | OutputCallback(Param param, | ||||
const VarNodeArray& inputs, | const VarNodeArray& inputs, | ||||
@@ -80,6 +81,7 @@ protected: | |||||
void scn_do_execute() override; | void scn_do_execute() override; | ||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
NodeProp* do_make_node_prop() const override; | NodeProp* do_make_node_prop() const override; | ||||
void add_input_layout_constraint() override; | |||||
private: | private: | ||||
Param m_param; | Param m_param; | ||||
mutable bool m_use_host_value; | mutable bool m_use_host_value; | ||||