|
|
@@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { |
|
|
|
return Broadcast::make(); |
|
|
|
} |
|
|
|
|
|
|
|
cg::OperatorNodeBase* apply_on_var_node( |
|
|
|
auto apply_on_var_node( |
|
|
|
const OpDef& def, |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
auto&& op = def.cast_final_safe<Broadcast>(); |
|
|
|
size_t nr_inp = inputs.size(); |
|
|
|
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); |
|
|
|
OperatorNodeConfig config{op.make_name()}; |
|
|
|
return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); |
|
|
|
return opr::Broadcast::make(inputs[0], inputs[1], config); |
|
|
|
} |
|
|
|
|
|
|
|
bool valid_broadcast(const TensorShape& src_shape, |
|
|
|