Browse Source

fix(imperative/opr): fix apply_on_var_node for broadcast

GitOrigin-RevId: 686fff4f73
release-1.3
Megvii Engine Team 4 years ago
parent
commit
1edcfa19a8
2 changed files with 7 additions and 2 deletions
  1. +5
    -0
      imperative/python/test/unit/functional/test_tensor.py
  2. +2
    -2
      imperative/src/impl/ops/broadcast.cpp

+ 5
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -309,12 +309,17 @@ def test_broadcast():
output2_shape = (20, 10, 20)
data2 = np.random.random(input2_shape).astype(np.float32)

input3_shape = (10, 10)
output3_shape = (10, 10)
data3 = np.random.random(input3_shape).astype(np.float32)

def compare_fn(x, y):
assert x.shape[0] == y

cases = [
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
{"input": [data3, output3_shape], "output": output3_shape},
]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)



+ 2
- 2
imperative/src/impl/ops/broadcast.cpp View File

@@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
return Broadcast::make();
}

cg::OperatorNodeBase* apply_on_var_node(
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
OperatorNodeConfig config{op.make_name()};
return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
return opr::Broadcast::make(inputs[0], inputs[1], config);
}

bool valid_broadcast(const TensorShape& src_shape,


Loading…
Cancel
Save