diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 439d4cc6..abb54b37 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -215,8 +215,9 @@ SymbolVar NvOf::make( void NvOf::scn_do_execute() { auto input_shape = this->input()[0]->shape(); + std::vector t_shape; for (size_t i = 0; i < 5; i++) { - vshape.push_back(input_shape[i]); + t_shape.push_back(input_shape[i]); } auto c = this->comp_node(); //! comp_node may init on CUDA or CPU, eg: lar with --cpu @@ -232,7 +233,8 @@ void NvOf::scn_do_execute() { //! create NvOF engine at same device id of comp_node, can not get //! comp_node device id, when NvOf:NvOf, so init at scn_do_execute std::lock_guard lock(m_lock); - if (init_flag == false) { + if (init_flag == false || vshape != t_shape) { + vshape = t_shape; //! nvof sdk do not imp p2p copy, so init nvof engine on the same //! device with mgb comp_node nv_flow_extractor = std::make_shared(