GitOrigin-RevId: 686fff4f73
release-1.3
@@ -309,12 +309,17 @@ def test_broadcast(): | |||||
output2_shape = (20, 10, 20) | output2_shape = (20, 10, 20) | ||||
data2 = np.random.random(input2_shape).astype(np.float32) | data2 = np.random.random(input2_shape).astype(np.float32) | ||||
input3_shape = (10, 10) | |||||
output3_shape = (10, 10) | |||||
data3 = np.random.random(input3_shape).astype(np.float32) | |||||
def compare_fn(x, y): | def compare_fn(x, y): | ||||
assert x.shape[0] == y | assert x.shape[0] == y | ||||
cases = [ | cases = [ | ||||
{"input": [data1, output1_shape], "output": output1_shape}, | {"input": [data1, output1_shape], "output": output1_shape}, | ||||
{"input": [data2, output2_shape], "output": output2_shape}, | {"input": [data2, output2_shape], "output": output2_shape}, | ||||
{"input": [data3, output3_shape], "output": output3_shape}, | |||||
] | ] | ||||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | ||||
@@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
return Broadcast::make(); | return Broadcast::make(); | ||||
} | } | ||||
cg::OperatorNodeBase* apply_on_var_node( | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | const OpDef& def, | ||||
const VarNodeArray& inputs) { | const VarNodeArray& inputs) { | ||||
auto&& op = def.cast_final_safe<Broadcast>(); | auto&& op = def.cast_final_safe<Broadcast>(); | ||||
size_t nr_inp = inputs.size(); | size_t nr_inp = inputs.size(); | ||||
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); | |||||
return opr::Broadcast::make(inputs[0], inputs[1], config); | |||||
} | } | ||||
bool valid_broadcast(const TensorShape& src_shape, | bool valid_broadcast(const TensorShape& src_shape, | ||||