diff --git a/src/opr/impl/dnn/roi_align.cpp b/src/opr/impl/dnn/roi_align.cpp index 3c29417f..e89ec8dc 100644 --- a/src/opr/impl/dnn/roi_align.cpp +++ b/src/opr/impl/dnn/roi_align.cpp @@ -42,9 +42,6 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois, #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIAlignForward) { - if (out_grad[1]) { - return InvalidGrad::make(opr, wrt_idx); - } if (wrt_idx == 0) { // wrt src SymbolVar grad = diff --git a/src/opr/impl/dnn/roi_pooling.cpp b/src/opr/impl/dnn/roi_pooling.cpp index ab2801d5..7c2d3df9 100644 --- a/src/opr/impl/dnn/roi_pooling.cpp +++ b/src/opr/impl/dnn/roi_pooling.cpp @@ -86,7 +86,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes( #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIPoolingForward) { - if (out_grad[1] || wrt_idx == 2) { + if (wrt_idx == 2) { return InvalidGrad::make(opr, wrt_idx); } if (wrt_idx == 0) {