Browse Source

fix(ops): implement from_op_node for reshape

GitOrigin-RevId: 4c99438504
release-1.10
Megvii Engine Team 3 years ago
parent
commit
4b27e861f4
2 changed files with 8 additions and 2 deletions
  1. +6
    -0
      imperative/src/impl/ops/broadcast.cpp
  2. +2
    -2
      imperative/src/impl/ops/tensor_manip.cpp

+ 6
- 0
imperative/src/impl/ops/broadcast.cpp View File

@@ -125,6 +125,11 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)


namespace reshape { namespace reshape {


auto make_from_op_node(const cg::OperatorNodeBase* node) {
auto& opr = node->cast_final_safe<opr::Reshape>();
return Reshape::make(opr.param(), std::vector<int32_t>());
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Reshape&>(def); auto&& op = static_cast<const Reshape&>(def);
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
@@ -261,6 +266,7 @@ OP_TRAIT_REG(Reshape, Reshape)
.infer_output_attrs_fallible(infer_output_attrs_fallible) .infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor) .apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint) .get_input_layout_constraint(get_input_layout_constraint)
.make_from_op_node(make_from_op_node)
.fallback(); .fallback();
} // namespace reshape } // namespace reshape




+ 2
- 2
imperative/src/impl/ops/tensor_manip.cpp View File

@@ -87,7 +87,7 @@ HostTensorND get_var_shape_host_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) { const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<DeviceTensorND> input_tensornds; SmallVector<DeviceTensorND> input_tensornds;
for (auto&& inp : inputs) { for (auto&& inp : inputs) {
input_tensornds.push_back(inp->dev_tensor());
input_tensornds.push_back(inp->dev_tensor(false));
} }
SmallVector<DeviceTensorND> output_tensornds = { SmallVector<DeviceTensorND> output_tensornds = {
{CompNode::default_cpu(), dtype::Int32()}}; {CompNode::default_cpu(), dtype::Int32()}};
@@ -100,7 +100,7 @@ HostTensorND get_var_shape_host_tensor(
SmallVector<TensorPtr> apply_on_physical_tensor( SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs, const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
return {Tensor::make(get_var_shape_host_tensor(def, inputs))};
} }


std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(


Loading…
Cancel
Save