diff --git a/src/opr/impl/dnn/dnn.oprdecl b/src/opr/impl/dnn/dnn.oprdecl index 224d251b..04d8f81b 100644 --- a/src/opr/impl/dnn/dnn.oprdecl +++ b/src/opr/impl/dnn/dnn.oprdecl @@ -315,5 +315,9 @@ r""" """), has_out_dtype=True) +decl_opr('FakeQuant', + inputs=[Doc('src','input tenosr'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor')], + params='FakeQuant') + # vim: ft=python diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 345bfd05..3aff3138 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -18,6 +18,7 @@ #include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" +#include "megbrain/opr/dnn/fake_quant.h" #include "megbrain/serialization/sereg.h" @@ -423,6 +424,8 @@ namespace opr { MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); MGB_SEREG_OPR(BatchConvBiasForward, 0); + MGB_SEREG_OPR(FakeQuant, 3); + MGB_SEREG_OPR(FakeQuantBackward, 4); } // namespace opr diff --git a/src/opr/impl/dnn/fake_quant.cpp b/src/opr/impl/dnn/fake_quant.cpp new file mode 100644 index 00000000..f910947e --- /dev/null +++ b/src/opr/impl/dnn/fake_quant.cpp @@ -0,0 +1,40 @@ +/** + * \file src/opr/impl/dnn/fake_quant.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/opr/dnn/fake_quant.h" +#include "../internal/megdnn_opr_wrapper.inl" +#include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/internal/out_shape_by_sym_var.h" +#include "megbrain/opr/utility.h" + +using namespace mgb; +using namespace opr; + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantForward); +MEGDNN_OPR_INIT3(FakeQuantForward, "fakequant_fwd"); + +#ifdef MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(FakeQuantForward) { + if (wrt_idx == 0) { + // wrt src + SymbolVar grad = + FakeQuantBackward::make(out_grad[0], opr.input(0), opr.input(1), + opr.input(2), opr.param()); + return grad.node(); + } else { + return nullptr; + } +} +#endif + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantBackward); +MEGDNN_OPR_INIT4(FakeQuantBackward, "fakequant_bwd", 1, true); diff --git a/src/opr/include/megbrain/opr/dnn/fake_quant.h b/src/opr/include/megbrain/opr/dnn/fake_quant.h new file mode 100644 index 00000000..e73fe1ad --- /dev/null +++ b/src/opr/include/megbrain/opr/dnn/fake_quant.h @@ -0,0 +1,44 @@ +/** + * \file src/opr/include/megbrain/opr/dnn/fake_quant.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megdnn/oprs.h" +namespace mgb { +namespace opr { +MGB_DEFINE_OPR_CLASS(FakeQuantForward, + intl::MegDNNOprWrapperFwd) // { +public: +FakeQuantForward(VarNode* src, VarNode* scale, VarNode* zero_point, + const Param& param, const OperatorNodeConfig& config); + +static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point, + const Param& param = {}, + const OperatorNodeConfig& config = {}); +}; // namespace opr +using FakeQuant = FakeQuantForward; + +MGB_DEFINE_OPR_CLASS(FakeQuantBackward, + intl::MegDNNOprWrapperBwd) // { +public: +FakeQuantBackward(VarNode* diff, VarNode* input, VarNode* scale, + VarNode* zero_point, const Param& param, + const OperatorNodeConfig& config); + +static SymbolVar make(SymbolVar diff, SymbolVar input, SymbolVar scale, + SymbolVar zero_point, const Param& param = {}, + const OperatorNodeConfig& config = {}); + +}; + +} // namespace mgb +} // namespace opr \ No newline at end of file diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 83e1fa28..be47f360 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -102,6 +102,7 @@ union OperatorParam { param.AdaptivePooling = 70, param.NvOf = 71, param.DctChannelSelect = 72, + param.FakeQuant = 73, } table Operator {