|
|
@@ -17,7 +17,7 @@ |
|
|
|
namespace mgb { |
|
|
|
namespace imperative { |
|
|
|
|
|
|
|
namespace { |
|
|
|
namespace broadcast { |
|
|
|
|
|
|
|
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { |
|
|
|
node_->cast_final_safe<opr::Broadcast>(); |
|
|
@@ -39,7 +39,7 @@ bool valid_broadcast(const TensorShape& src_shape, |
|
|
|
if (src_ndim > tar_ndim) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim; |
|
|
|
size_t min_ndim = src_ndim; |
|
|
|
for (size_t i = 0; i < min_ndim; ++i) { |
|
|
|
if (src_shape[src_ndim - i - 1] != 1 && |
|
|
|
src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) { |
|
|
@@ -87,7 +87,70 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.infer_output_attrs_fallible(infer_output_attrs_fallible) |
|
|
|
.fallback(); |
|
|
|
} // anonymous namespace |
|
|
|
} // broadcast |
|
|
|
|
|
|
|
namespace reshape { |
|
|
|
|
|
|
|
auto apply_on_var_node( |
|
|
|
const OpDef& def, |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
auto&& op = static_cast<const Reshape&>(def); |
|
|
|
mgb_assert(inputs.size() == 2); |
|
|
|
return opr::Reshape::make(inputs[0], inputs[1], op.param()); |
|
|
|
} |
|
|
|
|
|
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
const OpDef& def, |
|
|
|
const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
auto&& op = def.cast_final_safe<Reshape>(); |
|
|
|
size_t nr_inp = inputs.size(); |
|
|
|
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); |
|
|
|
auto&& src = inputs[0]; |
|
|
|
auto&& tshp = inputs[1]; |
|
|
|
|
|
|
|
TensorLayout out_layout = src.layout; |
|
|
|
if (tshp.layout.ndim == 0 || tshp.value.empty()) { |
|
|
|
out_layout.ndim = 0; |
|
|
|
return {{{out_layout, src.comp_node}}, false}; |
|
|
|
} |
|
|
|
mgb_assert( |
|
|
|
tshp.layout.ndim == 1, |
|
|
|
"target shape of Broadcast expects ndim=1; got ndim=%lu actually", |
|
|
|
tshp.layout.ndim); |
|
|
|
|
|
|
|
size_t target_ndim = tshp.layout.shape[0]; |
|
|
|
out_layout.ndim = target_ndim; |
|
|
|
auto* ptr = tshp.value.ptr<dt_int32>(); |
|
|
|
for (size_t i = 0; i < target_ndim; ++i) { |
|
|
|
out_layout.shape[i] = ptr[i]; |
|
|
|
} |
|
|
|
|
|
|
|
if (src.layout.ndim == 0) { |
|
|
|
return {{{out_layout, src.comp_node}}, false}; |
|
|
|
} |
|
|
|
|
|
|
|
if (op.axis != opr::Reshape::Param::INVALID_AXIS) { |
|
|
|
mgb_assert(out_layout.shape[op.axis] == -1); |
|
|
|
out_layout.shape[op.axis] = 1; |
|
|
|
mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0, |
|
|
|
"can not reshape from %s to %s", |
|
|
|
src.layout.to_string().c_str(), |
|
|
|
out_layout.to_string().c_str()); |
|
|
|
out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems(); |
|
|
|
} else { |
|
|
|
mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(), |
|
|
|
"can not reshape from %s to %s", |
|
|
|
src.layout.to_string().c_str(), |
|
|
|
out_layout.to_string().c_str()); |
|
|
|
} |
|
|
|
return {{{out_layout, src.comp_node}}, true}; |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(Reshape, Reshape) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.infer_output_attrs_fallible(infer_output_attrs_fallible) |
|
|
|
.fallback(); |
|
|
|
} // reshape |
|
|
|
|
|
|
|
} // namespace imperative |
|
|
|
} // namespace mgb |
|
|
|