Browse Source

feat(imperative/opr-mm): add broadcast

GitOrigin-RevId: 83640255c7
release-0.6
Megvii Engine Team 4 years ago
parent
commit
d3d9018f8d
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      src/opr-mm/impl/collective_comm.cpp

+ 1
- 1
src/opr-mm/impl/collective_comm.cpp View File

@@ -810,7 +810,7 @@ void CollectiveComm::init_output_static_infer_desc() {
};

auto get_shape_from_server = [this](TensorShape& dest, const InpVal&) {
if (!m_enable_shape_infer) {
if (!m_enable_shape_infer && !owner_graph()->options().imperative_proxy_graph) {
return false;
}



Loading…
Cancel
Save