diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index 46e2b541..b8ea6fdf 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -810,7 +810,7 @@ void CollectiveComm::init_output_static_infer_desc() { }; auto get_shape_from_server = [this](TensorShape& dest, const InpVal&) { - if (!m_enable_shape_infer) { + if (!m_enable_shape_infer && !owner_graph()->options().imperative_proxy_graph) { return false; }