@@ -315,5 +315,9 @@ r""" | |||||
"""), | """), | ||||
has_out_dtype=True) | has_out_dtype=True) | ||||
decl_opr('FakeQuant', | |||||
inputs=[Doc('src','input tenosr'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor')], | |||||
params='FakeQuant') | |||||
# vim: ft=python | # vim: ft=python |
@@ -18,6 +18,7 @@ | |||||
#include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
#include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
#include "megbrain/opr/dnn/fake_quant.h" | |||||
#include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
@@ -423,6 +424,8 @@ namespace opr { | |||||
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | ||||
MGB_SEREG_OPR(BatchConvBiasForward, 0); | MGB_SEREG_OPR(BatchConvBiasForward, 0); | ||||
MGB_SEREG_OPR(FakeQuant, 3); | |||||
MGB_SEREG_OPR(FakeQuantBackward, 4); | |||||
} // namespace opr | } // namespace opr | ||||
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file src/opr/impl/dnn/fake_quant.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/fake_quant.h" | |||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
#include "megbrain/graph/grad_impl.h" | |||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megbrain/opr/utility.h" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantForward); | |||||
MEGDNN_OPR_INIT3(FakeQuantForward, "fakequant_fwd"); | |||||
#ifdef MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(FakeQuantForward) { | |||||
if (wrt_idx == 0) { | |||||
// wrt src | |||||
SymbolVar grad = | |||||
FakeQuantBackward::make(out_grad[0], opr.input(0), opr.input(1), | |||||
opr.input(2), opr.param()); | |||||
return grad.node(); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
#endif | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantBackward); | |||||
MEGDNN_OPR_INIT4(FakeQuantBackward, "fakequant_bwd", 1, true); |
@@ -0,0 +1,44 @@ | |||||
/** | |||||
* \file src/opr/include/megbrain/opr/dnn/fake_quant.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megdnn/oprs.h" | |||||
namespace mgb { | |||||
namespace opr { | |||||
MGB_DEFINE_OPR_CLASS(FakeQuantForward, | |||||
intl::MegDNNOprWrapperFwd<megdnn::FakeQuantForward>) // { | |||||
public: | |||||
FakeQuantForward(VarNode* src, VarNode* scale, VarNode* zero_point, | |||||
const Param& param, const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point, | |||||
const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
}; // namespace opr | |||||
using FakeQuant = FakeQuantForward; | |||||
MGB_DEFINE_OPR_CLASS(FakeQuantBackward, | |||||
intl::MegDNNOprWrapperBwd<megdnn::FakeQuantBackward>) // { | |||||
public: | |||||
FakeQuantBackward(VarNode* diff, VarNode* input, VarNode* scale, | |||||
VarNode* zero_point, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar diff, SymbolVar input, SymbolVar scale, | |||||
SymbolVar zero_point, const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
}; | |||||
} // namespace mgb | |||||
} // namespace opr |
@@ -102,6 +102,7 @@ union OperatorParam { | |||||
param.AdaptivePooling = 70, | param.AdaptivePooling = 70, | ||||
param.NvOf = 71, | param.NvOf = 71, | ||||
param.DctChannelSelect = 72, | param.DctChannelSelect = 72, | ||||
param.FakeQuant = 73, | |||||
} | } | ||||
table Operator { | table Operator { | ||||