Browse Source

fix(mge/dnn): fix rng and topk oom in distributed training

GitOrigin-RevId: 9841d1219e
release-1.2
Megvii Engine Team 4 years ago
parent
commit
a240d558f1
2 changed files with 4 additions and 0 deletions
  1. +2
    -0
      src/opr/impl/misc.cpp
  2. +2
    -0
      src/opr/impl/rand.cpp

+ 2
- 0
src/opr/impl/misc.cpp View File

@@ -379,6 +379,8 @@ void TopK::init_output_static_infer_desc() {
}

auto infer_workspace = [this](TensorShape& dst, const InpVal& iv) {
// active comp_node for cuda launch kernel in get_workspace_in_bytes
comp_node().activate();
auto k = iv.val[3].value().ptr<int>()[0];
auto size = megdnn_opr()->get_workspace_in_bytes(
k, {iv.val[0].shape(), input(0)->dtype()},


+ 2
- 0
src/opr/impl/rand.cpp View File

@@ -60,6 +60,8 @@ cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const {

void RNGOprBase::ensure_megdnn_opr() {
if (!m_megdnn_opr || m_megdnn_opr.comp_node() != comp_node()) {
// activate comp_node for curandCreateGenerator in create_megdnn_opr
comp_node().activate();
m_megdnn_opr = create_megdnn_opr();
}
}


Loading…
Cancel
Save