Browse Source

feat(mgb/opr): add megbrain fake quant opr

GitOrigin-RevId: a858bce939
release-1.1
Megvii Engine Team 4 years ago
parent
commit
b60cc8ca3e
5 changed files with 92 additions and 0 deletions
  1. +4
    -0
      src/opr/impl/dnn/dnn.oprdecl
  2. +3
    -0
      src/opr/impl/dnn/dnn.sereg.h
  3. +40
    -0
      src/opr/impl/dnn/fake_quant.cpp
  4. +44
    -0
      src/opr/include/megbrain/opr/dnn/fake_quant.h
  5. +1
    -0
      src/serialization/impl/schema.fbs

+ 4
- 0
src/opr/impl/dnn/dnn.oprdecl View File

@@ -315,5 +315,9 @@ r"""
"""), """),
has_out_dtype=True) has_out_dtype=True)


decl_opr('FakeQuant',
inputs=[Doc('src','input tenosr'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor')],
params='FakeQuant')



# vim: ft=python # vim: ft=python

+ 3
- 0
src/opr/impl/dnn/dnn.sereg.h View File

@@ -18,6 +18,7 @@
#include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/fake_quant.h"


#include "megbrain/serialization/sereg.h" #include "megbrain/serialization/sereg.h"


@@ -423,6 +424,8 @@ namespace opr {
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);


MGB_SEREG_OPR(BatchConvBiasForward, 0); MGB_SEREG_OPR(BatchConvBiasForward, 0);
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
} // namespace opr } // namespace opr






+ 40
- 0
src/opr/impl/dnn/fake_quant.cpp View File

@@ -0,0 +1,40 @@
/**
* \file src/opr/impl/dnn/fake_quant.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/opr/dnn/fake_quant.h"
#include "../internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"

using namespace mgb;
using namespace opr;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantForward);
MEGDNN_OPR_INIT3(FakeQuantForward, "fakequant_fwd");

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(FakeQuantForward) {
if (wrt_idx == 0) {
// wrt src
SymbolVar grad =
FakeQuantBackward::make(out_grad[0], opr.input(0), opr.input(1),
opr.input(2), opr.param());
return grad.node();
} else {
return nullptr;
}
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantBackward);
MEGDNN_OPR_INIT4(FakeQuantBackward, "fakequant_bwd", 1, true);

+ 44
- 0
src/opr/include/megbrain/opr/dnn/fake_quant.h View File

@@ -0,0 +1,44 @@
/**
* \file src/opr/include/megbrain/opr/dnn/fake_quant.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS(FakeQuantForward,
intl::MegDNNOprWrapperFwd<megdnn::FakeQuantForward>) // {
public:
FakeQuantForward(VarNode* src, VarNode* scale, VarNode* zero_point,
const Param& param, const OperatorNodeConfig& config);

static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point,
const Param& param = {},
const OperatorNodeConfig& config = {});
}; // namespace opr
using FakeQuant = FakeQuantForward;

MGB_DEFINE_OPR_CLASS(FakeQuantBackward,
intl::MegDNNOprWrapperBwd<megdnn::FakeQuantBackward>) // {
public:
FakeQuantBackward(VarNode* diff, VarNode* input, VarNode* scale,
VarNode* zero_point, const Param& param,
const OperatorNodeConfig& config);

static SymbolVar make(SymbolVar diff, SymbolVar input, SymbolVar scale,
SymbolVar zero_point, const Param& param = {},
const OperatorNodeConfig& config = {});

};

} // namespace mgb
} // namespace opr

+ 1
- 0
src/serialization/impl/schema.fbs View File

@@ -102,6 +102,7 @@ union OperatorParam {
param.AdaptivePooling = 70, param.AdaptivePooling = 70,
param.NvOf = 71, param.NvOf = 71,
param.DctChannelSelect = 72, param.DctChannelSelect = 72,
param.FakeQuant = 73,
} }


table Operator { table Operator {


Loading…
Cancel
Save