Browse Source

fix(ops): implement from_op_node for reshape

GitOrigin-RevId: 4c99438504
release-1.10
Megvii Engine Team 3 years ago
parent
commit
4b27e861f4
2 changed files with 8 additions and 2 deletions
  1. +6
    -0
      imperative/src/impl/ops/broadcast.cpp
  2. +2
    -2
      imperative/src/impl/ops/tensor_manip.cpp

+ 6
- 0
imperative/src/impl/ops/broadcast.cpp View File

@@ -125,6 +125,11 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)

namespace reshape {

auto make_from_op_node(const cg::OperatorNodeBase* node) {
auto& opr = node->cast_final_safe<opr::Reshape>();
return Reshape::make(opr.param(), std::vector<int32_t>());
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Reshape&>(def);
mgb_assert(inputs.size() == 2);
@@ -261,6 +266,7 @@ OP_TRAIT_REG(Reshape, Reshape)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint)
.make_from_op_node(make_from_op_node)
.fallback();
} // namespace reshape



+ 2
- 2
imperative/src/impl/ops/tensor_manip.cpp View File

@@ -87,7 +87,7 @@ HostTensorND get_var_shape_host_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<DeviceTensorND> input_tensornds;
for (auto&& inp : inputs) {
input_tensornds.push_back(inp->dev_tensor());
input_tensornds.push_back(inp->dev_tensor(false));
}
SmallVector<DeviceTensorND> output_tensornds = {
{CompNode::default_cpu(), dtype::Int32()}};
@@ -100,7 +100,7 @@ HostTensorND get_var_shape_host_tensor(
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
return {Tensor::make(get_var_shape_host_tensor(def, inputs))};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(


Loading…
Cancel
Save