diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index 48140e17..16984c4d 100644
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -17,7 +17,7 @@
 namespace mgb {
 namespace imperative {
-namespace {
+namespace broadcast {
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     node_->cast_final_safe<opr::Broadcast>();
     return Broadcast::make();
@@ -39,7 +39,7 @@ bool valid_broadcast(const TensorShape& src_shape,
     if (src_ndim > tar_ndim) {
         return false;
     }
-    size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim;
+    size_t min_ndim = src_ndim;
     for (size_t i = 0; i < min_ndim; ++i) {
         if (src_shape[src_ndim - i - 1] != 1 &&
             src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
@@ -87,7 +87,70 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
     .apply_on_var_node(apply_on_var_node)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
     .fallback();
-} // anonymous namespace
+} // broadcast
+
+namespace reshape {
+
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& op = static_cast<const Reshape&>(def);
+    mgb_assert(inputs.size() == 2);
+    return opr::Reshape::make(inputs[0], inputs[1], op.param());
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op = def.cast_final_safe<Reshape>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
+    auto&& src = inputs[0];
+    auto&& tshp = inputs[1];
+
+    TensorLayout out_layout = src.layout;
+    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
+        out_layout.ndim = 0;
+        return {{{out_layout, src.comp_node}}, false};
+    }
+    mgb_assert(
+            tshp.layout.ndim == 1,
+            "target shape of Reshape expects ndim=1; got ndim=%lu actually",
+            tshp.layout.ndim);
+
+    size_t target_ndim = tshp.layout.shape[0];
+    out_layout.ndim = target_ndim;
+    auto* ptr = tshp.value.ptr<dt_int32>();
+    for (size_t i = 0; i < target_ndim; ++i) {
+        out_layout.shape[i] = ptr[i];
+    }
+
+    if (src.layout.ndim == 0) {
+        return {{{out_layout, src.comp_node}}, false};
+    }
+
+    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
+        mgb_assert(out_layout.shape[op.axis] == -1);
+        out_layout.shape[op.axis] = 1;
+        mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0,
+                "can not reshape from %s to %s",
+                src.layout.to_string().c_str(),
+                out_layout.to_string().c_str());
+        out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems();
+    } else {
+        mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(),
+                "can not reshape from %s to %s",
+                src.layout.to_string().c_str(),
+                out_layout.to_string().c_str());
+    }
+    return {{{out_layout, src.comp_node}}, true};
+}
+
+OP_TRAIT_REG(Reshape, Reshape)
+    .apply_on_var_node(apply_on_var_node)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .fallback();
+} // reshape
 
 } // namespace imperative
 } // namespace mgb
diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index 15392a09..f56e2df8 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -548,19 +548,6 @@ OP_TRAIT_REG(Remap, Remap)
     .fallback();
 }} // remap
 
-namespace { namespace reshape {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const Reshape&>(def);
-    mgb_assert(inputs.size() == 2);
-    return opr::Reshape::make(inputs[0], inputs[1], op.param());
-}
-OP_TRAIT_REG(Reshape, Reshape)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // reshape
-
 namespace {
 auto get_index(
         const VarNodeArray& inputs, size_t vidx,
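
Note (illustration only, not part of the patch): the new infer_output_attrs_fallible resolves a single -1 entry in the target shape by setting that entry to 1, asserting that the source element count is divisible by the product of the remaining target dimensions, and storing the quotient. Below is a minimal standalone sketch of that rule in plain C++ with no MegEngine dependencies; resolve_reshape and the std::vector<long> shapes are hypothetical names used only for this example.

#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Replace a single -1 entry in `target` with the value that keeps the total
// element count of `src` unchanged; mirrors the total_nr_elems() division in
// the patch. Hypothetical helper, for illustration only.
std::vector<long> resolve_reshape(const std::vector<long>& src, std::vector<long> target) {
    long src_elems = std::accumulate(src.begin(), src.end(), 1L, std::multiplies<long>());
    long known = 1;
    int unknown_axis = -1;
    for (std::size_t i = 0; i < target.size(); ++i) {
        if (target[i] == -1) {
            assert(unknown_axis == -1 && "at most one -1 axis is allowed");
            unknown_axis = static_cast<int>(i);
        } else {
            known *= target[i];
        }
    }
    if (unknown_axis >= 0) {
        assert(src_elems % known == 0 && "can not reshape: sizes do not divide");
        target[unknown_axis] = src_elems / known;
    } else {
        assert(src_elems == known && "can not reshape: element counts differ");
    }
    return target;
}

int main() {
    // A (2, 3, 4) source reshaped to (-1, 4) resolves to (6, 4): 24 elements preserved.
    auto out = resolve_reshape({2, 3, 4}, {-1, 4});
    assert(out[0] == 6 && out[1] == 4);
    return 0;
}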