
fix(mgb): fix profile skip condition

GitOrigin-RevId: f196eabc98
HuaHua404-patch-1
Megvii Engine Team, 3 years ago
parent commit 5a35513856
2 changed files with 21 additions and 1 deletion:

  1. +18 -0  imperative/src/impl/transformations/dtype_promote.cpp
  2. +3  -1  src/rdnn/impl/algo_chooser.cpp

imperative/src/impl/transformations/dtype_promote.cpp  (+18, -0)

@@ -186,6 +186,15 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
 ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>());
     SmallVector<DType> dtypes = get_value_dtypes(inputs);
+
+    // skip dtype promotion when inputs are quantized
+    if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) {
+        mgb_assert(
+                dtypes[0].category() == dtypes[1].category(),
+                "inputs of matmul should have same quantized dtype.");
+        return imperative::apply(op, inputs);
+    }
+
     mgb::DType target_dtype;
 
     if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
@@ -212,6 +221,15 @@ ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& conv_op =
             const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>());
     SmallVector<DType> dtypes = get_value_dtypes(inputs);
+
+    // skip dtype promotion when inputs are quantized
+    if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) {
+        mgb_assert(
+                dtypes[0].category() == dtypes[1].category(),
+                "inputs of batched matmul should have same quantized dtype.");
+        return imperative::apply(op, inputs);
+    }
+
     mgb::DType target_dtype;
 
     if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
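In rough terms, both matmul rules gain the same early-exit path: if the first input is quantized, assert that the second input's dtype category matches and apply the op unchanged instead of computing a promoted target dtype. The standalone sketch below illustrates that control flow with hypothetical stand-in types (DTypeCategory, Tensor, matmul_rule_sketch); these are not the real MegEngine classes, only a simplified mock of the logic.

    #include <cassert>
    #include <vector>

    // Hypothetical stand-ins for MegEngine's dtype machinery; illustration only.
    enum class DTypeCategory { FLOAT, INT, QUANTIZED };

    struct Tensor {
        DTypeCategory category;
    };

    // Mirrors the control flow matmul_rule / batch_matmul_rule gain in this
    // commit: quantized inputs bypass dtype promotion and are applied as-is.
    std::vector<Tensor> matmul_rule_sketch(const std::vector<Tensor>& inputs) {
        // skip dtype promotion when inputs are quantized
        if (inputs[0].category == DTypeCategory::QUANTIZED) {
            assert(inputs[1].category == inputs[0].category &&
                   "inputs of matmul should have same quantized dtype");
            return inputs;  // stands in for imperative::apply(op, inputs)
        }
        // otherwise the rule would pick a promoted target dtype (float32, or
        // float16 under AMP autocast) and convert the inputs before applying
        return inputs;
    }

    int main() {
        std::vector<Tensor> quantized = {
                {DTypeCategory::QUANTIZED}, {DTypeCategory::QUANTIZED}};
        matmul_rule_sketch(quantized);  // takes the early-exit path
        return 0;
    }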


src/rdnn/impl/algo_chooser.cpp  (+3, -1)

@@ -600,7 +600,9 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp
         auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn);
         // skip different sub opr, for example:
         // skip matmul algo when profiling convolution
-        if (m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type())
+        if ((m_cn.device_type() == mgb::CompNode::DeviceType::CUDA ||
+             m_cn.device_type() == mgb::CompNode::DeviceType::ROCM) &&
+            m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type())
             continue;
         megdnn_opr->param() =
                 Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param);
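The profiling change gates the existing skip on the device type: a sub-operator whose type differs from the profiled operator is now only skipped on CUDA and ROCM comp nodes, while other back ends keep considering such entries. A small self-contained sketch of the new condition follows, using hypothetical DeviceType/OprType enums and a should_skip helper rather than the real AlgoChooser API.

    #include <iostream>

    // Hypothetical stand-ins for comp-node device types and operator types;
    // illustration only, not MegEngine's actual enums.
    enum class DeviceType { CPU, CUDA, ROCM };
    enum class OprType { CONVOLUTION, MATRIX_MUL };

    // Decide whether a profiled sub-operator entry should be skipped.
    // After this commit, a type mismatch only causes a skip on CUDA/ROCM.
    bool should_skip(DeviceType device, OprType profiled, OprType candidate) {
        return (device == DeviceType::CUDA || device == DeviceType::ROCM) &&
               profiled != candidate;
    }

    int main() {
        // A CUDA comp node still skips a matmul algo while profiling a convolution...
        std::cout << should_skip(DeviceType::CUDA, OprType::CONVOLUTION,
                                 OprType::MATRIX_MUL)
                  << "\n";  // prints 1
        // ...but a CPU comp node no longer does.
        std::cout << should_skip(DeviceType::CPU, OprType::CONVOLUTION,
                                 OprType::MATRIX_MUL)
                  << "\n";  // prints 0
        return 0;
    }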

