@@ -1741,6 +1741,67 @@ protected: | |||
const TensorLayout& grad_s, size_t workspace_in_bytes); | |||
}; | |||
class LSQBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase); | |||
DEF_OPR_PARAM(LSQ); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output); | |||
void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale, | |||
const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, | |||
const TensorLayout& output); | |||
}; | |||
class LSQForward : public LSQBase { | |||
DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out output, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& input, const TensorLayout& scale, | |||
const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, TensorLayout& output); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& input, | |||
const TensorLayout& scale, | |||
const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, | |||
const TensorLayout& output) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& input, const TensorLayout& scale, | |||
const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, const TensorLayout& output, | |||
size_t workspace_in_bytes); | |||
}; | |||
using LSQ = LSQForward; | |||
class LSQBackward : public LSQBase { | |||
DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2); | |||
public: | |||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
_megdnn_tensor_out grad_s, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
const TensorLayout& input, | |||
const TensorLayout& scale, | |||
const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, | |||
const TensorLayout& grad_x, | |||
const TensorLayout& grad_s) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& diff, const TensorLayout& input, | |||
const TensorLayout& scale, const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, const TensorLayout& grad_x, | |||
const TensorLayout& grad_s, size_t workspace_in_bytes); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
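For reference, the forward operator declared above computes the standard LSQ fake-quantization (this matches the CUDA and naive kernels added later in this diff), with input x, step size s, zero point z, and integer bounds q_min, q_max:

\hat{x} = \operatorname{round}\!\left(\operatorname{clip}\!\left(\tfrac{x}{s} + z,\; q_{\min},\; q_{\max}\right)\right), \qquad y = (\hat{x} - z) \cdot s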
@@ -1124,3 +1124,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||
add_fields('int32', 'qmin', '-2147483648'). | |||
add_fields('int32', 'qmax', '2147483647') | |||
) | |||
(pdef('LSQ'). | |||
add_fields('int32', 'qmin', '-2147483648'). | |||
add_fields('int32', 'qmax', '2147483647') | |||
) | |||
@@ -37,6 +37,7 @@ namespace megdnn { | |||
megdnn_assert(size, "uninitialized ElemwiseOpParamN"); | |||
} | |||
template struct ElemwiseOpParamN<7>; | |||
template struct ElemwiseOpParamN<6>; | |||
template struct ElemwiseOpParamN<5>; | |||
template struct ElemwiseOpParamN<4>; | |||
@@ -208,7 +208,9 @@ private: | |||
cb(FakeQuantBackward) \ | |||
cb(TQTForward) \ | |||
cb(TQTBackward) \ | |||
cb(CheckHasInf) | |||
cb(CheckHasInf) \ | |||
cb(LSQForward) \ | |||
cb(LSQBackward) | |||
/*! | |||
* \brief specialize HandleImpl::create_operator for a single opr type; | |||
@@ -0,0 +1,69 @@ | |||
/** | |||
* \file dnn/src/common/lsq.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void LSQBase::deduce_layout_fwd(const TensorLayout& input, | |||
TensorLayout& output) { | |||
output = TensorLayout(input, input.dtype); | |||
} | |||
void LSQBase::check_layout_fwd(const TensorLayout& input, | |||
const TensorLayout& scale, | |||
const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, | |||
const TensorLayout& output) { | |||
megdnn_assert(input.dtype == dtype::Float32()); | |||
megdnn_assert(scale.dtype == dtype::Float32()); | |||
megdnn_assert(zero_point.dtype == dtype::Float32()); | |||
megdnn_assert(grad_scale.dtype == dtype::Float32()); | |||
TensorLayout expected; | |||
deduce_layout_fwd(input, expected); | |||
megdnn_assert_eq_layout(expected, output); | |||
} | |||
void LSQForward::deduce_layout(const TensorLayout& input, | |||
const TensorLayout& /* scale */, | |||
const TensorLayout& /*zero_point*/, | |||
const TensorLayout& /*grad_scale*/, | |||
TensorLayout& output) { | |||
deduce_layout_fwd(input, output); | |||
} | |||
void LSQForward::check_exec(const TensorLayout& input, | |||
const TensorLayout& scale, | |||
const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, | |||
const TensorLayout& output, | |||
size_t workspace_in_bytes) { | |||
check_layout_fwd(input, scale, zero_point, grad_scale, output); | |||
auto required_workspace_space = get_workspace_in_bytes( | |||
input, scale, zero_point, grad_scale, output); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_space); | |||
} | |||
void LSQBackward::check_exec( | |||
const TensorLayout& diff, const TensorLayout& input, | |||
const TensorLayout& scale, const TensorLayout& zero_point, | |||
const TensorLayout& grad_scale, const TensorLayout& grad_x, | |||
const TensorLayout& grad_s, size_t workspace_in_bytes) { | |||
megdnn_assert_eq_shape(diff, input); | |||
megdnn_assert_eq_shape(grad_x, input); | |||
auto required_workspace_space = get_workspace_in_bytes( | |||
diff, input, scale, zero_point, grad_scale, grad_x, grad_s); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_space); | |||
} | |||
} // namespace megdnn |
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
@@ -121,6 +122,8 @@ DEF(UniformRNG, 1, true, true); | |||
DEF(GaussianRNG, 1, true, true); | |||
DEF(ChecksumForward, 1, true, false); | |||
DEF(CheckHasInf, 2, true, true); | |||
DEF(LSQForward, 5, true, true); | |||
DEF(LSQBackward, 7, true, false); | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -947,6 +947,119 @@ struct OpCallerUniform<Op, 5, PVis> { | |||
} | |||
}; | |||
//! specialization for arity == 6 | |||
template <class Op, class PVis> | |||
struct OpCallerUniform<Op, 6, PVis> { | |||
Op op; | |||
PVis par[6]; | |||
static const uint32_t packed_size = PVis::packed_size; | |||
devfunc void thread_init(uint32_t idx) { | |||
idx = idx * packed_size; | |||
par[0].thread_init(idx); | |||
par[1].thread_init(idx); | |||
par[2].thread_init(idx); | |||
par[3].thread_init(idx); | |||
par[4].thread_init(idx); | |||
par[5].thread_init(idx); | |||
} | |||
devfunc void on(uint32_t idx) { | |||
idx = idx * packed_size; | |||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), | |||
par[4].at(idx), par[5].at(idx)); | |||
} | |||
devfunc void on(uint32_t idx, uint32_t remain) { | |||
idx = idx * packed_size; | |||
if (remain >= packed_size) { | |||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), | |||
par[3].at(idx), par[4].at(idx), par[5].at(idx)); | |||
} else { | |||
auto ptr0 = par[0].ptr(); | |||
auto ptr1 = par[1].ptr(); | |||
auto ptr2 = par[2].ptr(); | |||
auto ptr3 = par[3].ptr(); | |||
auto ptr4 = par[4].ptr(); | |||
auto ptr5 = par[5].ptr(); | |||
for (int i = 0; i < remain; i++) { | |||
op(idx + i, ptr0[par[0].offset(idx + i)], | |||
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], | |||
ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], | |||
ptr5[par[5].offset(idx + i)]); | |||
} | |||
} | |||
} | |||
devfunc void next() { | |||
par[0].next(); | |||
par[1].next(); | |||
par[2].next(); | |||
par[3].next(); | |||
par[4].next(); | |||
par[5].next(); | |||
} | |||
}; | |||
//! specialization for arity == 7 | |||
template <class Op, class PVis> | |||
struct OpCallerUniform<Op, 7, PVis> { | |||
Op op; | |||
PVis par[7]; | |||
static const uint32_t packed_size = PVis::packed_size; | |||
devfunc void thread_init(uint32_t idx) { | |||
idx = idx * packed_size; | |||
par[0].thread_init(idx); | |||
par[1].thread_init(idx); | |||
par[2].thread_init(idx); | |||
par[3].thread_init(idx); | |||
par[4].thread_init(idx); | |||
par[5].thread_init(idx); | |||
par[6].thread_init(idx); | |||
} | |||
devfunc void on(uint32_t idx) { | |||
idx = idx * packed_size; | |||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), | |||
par[4].at(idx), par[5].at(idx), par[6].at(idx)); | |||
} | |||
devfunc void on(uint32_t idx, uint32_t remain) { | |||
idx = idx * packed_size; | |||
if (remain >= packed_size) { | |||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), | |||
par[3].at(idx), par[4].at(idx), par[5].at(idx), par[6].at(idx)); | |||
} else { | |||
auto ptr0 = par[0].ptr(); | |||
auto ptr1 = par[1].ptr(); | |||
auto ptr2 = par[2].ptr(); | |||
auto ptr3 = par[3].ptr(); | |||
auto ptr4 = par[4].ptr(); | |||
auto ptr5 = par[5].ptr(); | |||
auto ptr6 = par[6].ptr(); | |||
for (int i = 0; i < remain; i++) { | |||
op(idx + i, ptr0[par[0].offset(idx + i)], | |||
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], | |||
ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], | |||
ptr5[par[5].offset(idx + i)], ptr6[par[6].offset(idx + i)]); | |||
} | |||
} | |||
} | |||
devfunc void next() { | |||
par[0].next(); | |||
par[1].next(); | |||
par[2].next(); | |||
par[3].next(); | |||
par[4].next(); | |||
par[5].next(); | |||
par[6].next(); | |||
} | |||
}; | |||
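//! Note: the new arity-6 and arity-7 specializations above mirror the existing
//! lower-arity ones; arity 7 is what LSQBwdKernOpNonContig needs (grad_x,
//! grad_s, diff, input, scale, zero_point, grad_scale), and arity 6 is
//! presumably instantiated alongside it for completeness.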
/*! | |||
* \brief call binary (i.e. arity == 2) operator with different param | |||
* visitors | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/common/handle_impl.h" | |||
@@ -15,6 +16,7 @@ | |||
#include "src/cuda/add_update/opr_impl.h" | |||
#include "src/cuda/argmxx/opr_impl.h" | |||
#include "src/cuda/argsort/opr_impl.h" | |||
#include "src/cuda/batch_conv_bias/opr_impl.h" | |||
#include "src/cuda/batch_normalization/opr_impl.h" | |||
#include "src/cuda/batched_matrix_mul/opr_impl.h" | |||
#include "src/cuda/check_has_inf/opr_impl.h" | |||
@@ -35,6 +37,7 @@ | |||
#include "src/cuda/elemwise/opr_impl.h" | |||
#include "src/cuda/elemwise_multi_type/opr_impl.h" | |||
#include "src/cuda/eye/opr_impl.h" | |||
#include "src/cuda/fake_quant/opr_impl.h" | |||
#include "src/cuda/flip/opr_impl.h" | |||
#include "src/cuda/gaussian_blur/opr_impl.h" | |||
#include "src/cuda/group_local/opr_impl.h" | |||
@@ -45,6 +48,7 @@ | |||
#include "src/cuda/local/opr_impl.h" | |||
#include "src/cuda/local_share/opr_impl.h" | |||
#include "src/cuda/lrn/opr_impl.h" | |||
#include "src/cuda/lsq/opr_impl.h" | |||
#include "src/cuda/mask_conv/opr_impl.h" | |||
#include "src/cuda/matrix_inverse/opr_impl.h" | |||
#include "src/cuda/matrix_mul/opr_impl.h" | |||
@@ -56,9 +60,11 @@ | |||
#include "src/cuda/reduce/opr_impl.h" | |||
#include "src/cuda/relayout/opr_impl.h" | |||
#include "src/cuda/relayout_format/opr_impl.h" | |||
#include "src/cuda/remap/opr_impl.h" | |||
#include "src/cuda/repeat/opr_impl.h" | |||
#include "src/cuda/resize/opr_impl.h" | |||
#include "src/cuda/rng/opr_impl.h" | |||
#include "src/cuda/roi_align/opr_impl.h" | |||
#include "src/cuda/roi_copy/opr_impl.h" | |||
#include "src/cuda/roi_pooling/opr_impl.h" | |||
#include "src/cuda/rotate/opr_impl.h" | |||
@@ -70,16 +76,11 @@ | |||
#include "src/cuda/tensor_remap/opr_impl.h" | |||
#include "src/cuda/tile/opr_impl.h" | |||
#include "src/cuda/topk/opr_impl.h" | |||
#include "src/cuda/tqt/opr_impl.h" | |||
#include "src/cuda/transpose/opr_impl.h" | |||
#include "src/cuda/type_cvt/opr_impl.h" | |||
#include "src/cuda/warp_affine/opr_impl.h" | |||
#include "src/cuda/warp_perspective/opr_impl.h" | |||
#include "src/cuda/local_share/opr_impl.h" | |||
#include "src/cuda/roi_align/opr_impl.h" | |||
#include "src/cuda/batch_conv_bias/opr_impl.h" | |||
#include "src/cuda/remap/opr_impl.h" | |||
#include "src/cuda/fake_quant/opr_impl.h" | |||
#include "src/cuda/tqt/opr_impl.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -0,0 +1,30 @@ | |||
/** | |||
* \file dnn/src/cuda/lsq/kern.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./kern.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
#define cb(_dtype) \ | |||
INST_RUN_ELEMWISE(LSQKernOp<DTypeTrait<_dtype>::ctype>, \ | |||
DTypeTrait<_dtype>::ctype, 3); \ | |||
INST_RUN_ELEMWISE(LSQBwdKernOp<DTypeTrait<_dtype>::ctype>, \ | |||
DTypeTrait<_dtype>::ctype, 3); \ | |||
INST_RUN_ELEMWISE(LSQKernOpNonContig<DTypeTrait<_dtype>::ctype>, \ | |||
DTypeTrait<_dtype>::ctype, 5); \ | |||
INST_RUN_ELEMWISE(LSQBwdKernOpNonContig<DTypeTrait<_dtype>::ctype>, \ | |||
DTypeTrait<_dtype>::ctype, 7); | |||
cb(megdnn::dtype::Float32) | |||
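// note: arity 3 covers the contiguous forward/backward paths (input, output,
// diff and grads live in the functor as raw pointers), while arities 5 and 7
// cover the non-contiguous variants where every tensor is visited per element.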
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,126 @@ | |||
/** | |||
* \file dnn/src/cuda/lsq/kern.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/elemwise_helper.cuh" | |||
#include "src/cuda/utils.cuh" | |||
#if MEGDNN_CC_HOST | |||
#include "megdnn/oprs.h" | |||
#endif | |||
namespace megdnn { | |||
namespace cuda { | |||
template <typename ctype> | |||
struct LSQKernOp { | |||
ctype* input; | |||
ctype* output; | |||
ctype qmin, qmax; | |||
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, | |||
ctype grad_scale) { | |||
ctype x = input[idx] / scale + zero_point; | |||
x = fmaxf(fminf(x, qmax), qmin); | |||
x = round(x); | |||
output[idx] = (x - zero_point) * scale; | |||
} | |||
#if MEGDNN_CC_HOST | |||
LSQKernOp(const TensorND& input, const TensorND& output, | |||
const LSQ::Param& param) | |||
: input{input.ptr<ctype>()}, | |||
output{output.ptr<ctype>()}, | |||
qmin(param.qmin), | |||
qmax(param.qmax) {} | |||
#endif | |||
}; | |||
template <typename ctype> | |||
struct LSQBwdKernOp { | |||
ctype* diff; | |||
ctype* input; | |||
ctype* grad_x; | |||
ctype* grad_s; | |||
ctype qmin, qmax; | |||
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, | |||
ctype grad_scale) { | |||
ctype x = input[idx] / scale + zero_point; | |||
bool ind_small = x < qmin; | |||
bool ind_big = x > qmax; | |||
bool ind_middle = ind_small ^ ind_big; | |||
ind_middle = !ind_middle; | |||
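// ind_middle is now true iff qmin <= x <= qmax, i.e. x was not clipped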
grad_s[idx] = ind_small * qmin + ind_big * qmax + | |||
ind_middle * (-x + round(x)); | |||
grad_s[idx] = grad_s[idx] * grad_scale * diff[idx]; | |||
grad_x[idx] = ind_middle * diff[idx]; | |||
} | |||
#if MEGDNN_CC_HOST | |||
LSQBwdKernOp(const TensorND& diff, const TensorND& input, | |||
const TensorND& grad_x, const TensorND& grad_s, | |||
const LSQ::Param& param) | |||
: diff{diff.ptr<ctype>()}, | |||
input{input.ptr<ctype>()}, | |||
grad_x{grad_x.ptr<ctype>()}, | |||
grad_s{grad_s.ptr<ctype>()}, | |||
qmin(param.qmin), | |||
qmax(param.qmax) {} | |||
#endif | |||
}; | |||
template <typename ctype> | |||
struct LSQKernOpNonContig { | |||
ctype qmin; | |||
ctype qmax; | |||
__device__ void operator()(uint32_t, ctype& output, ctype& input, | |||
ctype& scale, ctype& zero_point, | |||
ctype grad_scale) { | |||
ctype x = input / scale + zero_point; | |||
x = fmaxf(fminf(x, qmax), qmin); | |||
x = round(x); | |||
output = (x - zero_point) * scale; | |||
} | |||
#if MEGDNN_CC_HOST | |||
LSQKernOpNonContig(const LSQ::Param& param) | |||
: qmin(param.qmin), qmax(param.qmax) {} | |||
#endif | |||
}; | |||
template <typename ctype> | |||
struct LSQBwdKernOpNonContig { | |||
ctype qmin; | |||
ctype qmax; | |||
__device__ void operator()(uint32_t, ctype& grad_x, ctype& grad_s, | |||
ctype& diff, ctype& input, ctype& scale, | |||
ctype& zero_point, ctype grad_scale) { | |||
ctype x = input / scale + zero_point; | |||
bool ind_small = x < qmin; | |||
bool ind_big = x > qmax; | |||
bool ind_middle = ind_small ^ ind_big; | |||
ind_middle = !ind_middle; | |||
grad_s = ind_small * qmin + ind_big * qmax + | |||
ind_middle * (-x + round(x)); | |||
grad_s = grad_s * grad_scale * diff; | |||
grad_x = ind_middle * diff; | |||
} | |||
#if MEGDNN_CC_HOST | |||
LSQBwdKernOpNonContig(const LSQ::Param& param) | |||
: qmin(param.qmin), qmax(param.qmax) {} | |||
#endif | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
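The backward functors above implement the straight-through LSQ gradients. Writing \hat{x} = x/s + z and letting g denote grad_scale, each element satisfies:

\frac{\partial L}{\partial x} = \text{diff} \cdot \mathbf{1}[\,q_{\min} \le \hat{x} \le q_{\max}\,], \qquad
\frac{\partial L}{\partial s} = \text{diff} \cdot g \cdot
\begin{cases}
q_{\min}, & \hat{x} < q_{\min}, \\
q_{\max}, & \hat{x} > q_{\max}, \\
\operatorname{round}(\hat{x}) - \hat{x}, & \text{otherwise.}
\end{cases}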
@@ -0,0 +1,151 @@ | |||
/** | |||
* \file dnn/src/cuda/lsq/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./opr_impl.h" | |||
#include "./kern.cuh" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out output, | |||
_megdnn_workspace workspace) { | |||
check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout, | |||
output.layout, workspace.size); | |||
if (!input.layout.is_contiguous() || !output.layout.is_contiguous()) | |||
return exec_noncontig(input, scale, zero_point, grad_scale, output); | |||
ElemwiseOpParamN<3> ele_param; | |||
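// the three visited params (scale, zero_point, grad_scale) are broadcast to
// input's full layout so the arity-3 elemwise kernel can index them per
// element; input and output are handed to LSQKernOp as raw pointers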
ele_param[0] = scale; | |||
ele_param[0].layout = ele_param[0].layout.broadcast(input.layout); | |||
ele_param[1] = zero_point; | |||
ele_param[1].layout = ele_param[1].layout.broadcast(input.layout); | |||
ele_param[2] = grad_scale; | |||
ele_param[2].layout = ele_param[2].layout.broadcast(input.layout); | |||
ele_param.init_from_given_tensor(); | |||
auto m_param = param(); | |||
auto stream = cuda_stream(handle()); | |||
#define cb(DType) \ | |||
if (input.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
run_elemwise<LSQKernOp<T>, T, 3>(ele_param, stream, \ | |||
{input, output, m_param}); \ | |||
return; \ | |||
} | |||
cb(megdnn::dtype::Float32) | |||
#undef cb | |||
} | |||
void LSQForwardImpl::exec_noncontig(_megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out output) { | |||
ElemwiseOpParamN<5> ele_param; | |||
ele_param[0] = output; | |||
ele_param[1] = input; | |||
ele_param[2] = scale; | |||
ele_param[2].layout = ele_param[2].layout.broadcast(input.layout); | |||
ele_param[3] = zero_point; | |||
ele_param[3].layout = ele_param[3].layout.broadcast(input.layout); | |||
ele_param[4] = grad_scale; | |||
ele_param[4].layout = ele_param[4].layout.broadcast(input.layout); | |||
ele_param.init_from_given_tensor(); | |||
auto m_param = param(); | |||
auto stream = cuda_stream(handle()); | |||
#define cb(DType) \ | |||
if (input.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
run_elemwise<LSQKernOpNonContig<T>, T, 5>(ele_param, stream, \ | |||
{m_param}); \ | |||
return; \ | |||
} | |||
cb(megdnn::dtype::Float32) | |||
#undef cb | |||
} | |||
void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, input.layout, scale.layout, zero_point.layout, | |||
grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size); | |||
if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() || | |||
!grad_x.layout.is_contiguous() || !grad_s.layout.is_contiguous()) | |||
return exec_noncontig(diff, input, scale, zero_point, grad_scale, | |||
grad_x, grad_s); | |||
ElemwiseOpParamN<3> ele_param; | |||
ele_param[0] = scale; | |||
ele_param[0].layout = ele_param[0].layout.broadcast(input.layout); | |||
ele_param[1] = zero_point; | |||
ele_param[1].layout = ele_param[1].layout.broadcast(input.layout); | |||
ele_param[2] = grad_scale; | |||
ele_param[2].layout = ele_param[2].layout.broadcast(input.layout); | |||
ele_param.init_from_given_tensor(); | |||
auto m_param = param(); | |||
auto stream = cuda_stream(handle()); | |||
#define cb(DType) \ | |||
if (grad_x.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
run_elemwise<LSQBwdKernOp<T>, T, 3>( \ | |||
ele_param, stream, {diff, input, grad_x, grad_s, m_param}); \ | |||
return; \ | |||
} | |||
cb(megdnn::dtype::Float32) | |||
#undef cb | |||
} | |||
void LSQBackwardImpl::exec_noncontig(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out grad_x, | |||
_megdnn_tensor_out grad_s) { | |||
ElemwiseOpParamN<7> ele_param; | |||
ele_param[0] = grad_x; | |||
ele_param[1] = grad_s; | |||
ele_param[2] = diff; | |||
ele_param[3] = input; | |||
ele_param[4] = scale; | |||
ele_param[4].layout = ele_param[4].layout.broadcast(input.layout); | |||
ele_param[5] = zero_point; | |||
ele_param[5].layout = ele_param[5].layout.broadcast(input.layout); | |||
ele_param[6] = grad_scale; | |||
ele_param[6].layout = ele_param[6].layout.broadcast(input.layout); | |||
ele_param.init_from_given_tensor(); | |||
auto m_param = param(); | |||
auto stream = cuda_stream(handle()); | |||
#define cb(DType) \ | |||
if (input.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
run_elemwise<LSQBwdKernOpNonContig<T>, T, 7>(ele_param, stream, \ | |||
{m_param}); \ | |||
return; \ | |||
} | |||
cb(megdnn::dtype::Float32) | |||
#undef cb | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,65 @@ | |||
/** | |||
* \file dnn/src/cuda/lsq/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class LSQForwardImpl final : public LSQForward { | |||
public: | |||
using LSQForward::LSQForward; | |||
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out output, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, /* input */ | |||
const TensorLayout&, /* scale */ | |||
const TensorLayout&, /* zero_point */ | |||
const TensorLayout&, /* grad_scale */ | |||
const TensorLayout& /* output */) override { | |||
return 0; | |||
} | |||
private: | |||
void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out output); | |||
}; | |||
class LSQBackwardImpl final : public LSQBackward { | |||
public: | |||
using LSQBackward::LSQBackward; | |||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
_megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& /* diff */, | |||
const TensorLayout& /* input */, | |||
const TensorLayout& /* scale */, | |||
const TensorLayout& /* zero_point */, | |||
const TensorLayout& /* grad_scale */, | |||
const TensorLayout& /* grad_x */, | |||
const TensorLayout& /* grad_s */) override { | |||
return 0; | |||
} | |||
private: | |||
void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
_megdnn_tensor_out grad_s); | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -50,6 +50,7 @@ | |||
#include "src/naive/local/opr_impl.h" | |||
#include "src/naive/local_share/opr_impl.h" | |||
#include "src/naive/lrn/opr_impl.h" | |||
#include "src/naive/lsq/opr_impl.h" | |||
#include "src/naive/mask_conv/opr_impl.h" | |||
#include "src/naive/matrix_inverse/opr_impl.h" | |||
#include "src/naive/matrix_mul/opr_impl.h" | |||
@@ -0,0 +1,141 @@ | |||
/** | |||
* \file dnn/src/naive/lsq/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/naive/lsq/opr_impl.h" | |||
#include <cmath> | |||
#include "megdnn/tensor_iter.h" | |||
#include "src/common/elemwise_helper.cuh" | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
namespace { | |||
using namespace megdnn; | |||
template <typename T> | |||
void forward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) { | |||
auto inp = tensor_iter_valonly<T>(src[0]).begin(); | |||
auto out = tensor_iter_valonly<T>(src[1]).begin(); | |||
auto scale = tensor_iter_valonly<T>(src[2]).begin(); | |||
auto zero_point = tensor_iter_valonly<T>(src[3]).begin(); | |||
auto grad_scale = tensor_iter_valonly<T>(src[4]).begin(); | |||
size_t total = src[0].layout.total_nr_elems(); | |||
for (size_t i = 0; i < total; ++i) { | |||
T x = (*inp) / (*scale) + (*zero_point); | |||
x = x <= qmin ? qmin : x; | |||
x = x >= qmax ? qmax : x; | |||
x = round(x); | |||
*out = (x - (*zero_point)) * (*scale); | |||
++inp; | |||
++out; | |||
++scale; | |||
++zero_point; | |||
++grad_scale; | |||
} | |||
} | |||
template <typename T> | |||
void backward_impl(const ElemwiseOpParamN<7> src, float qmin, float qmax) { | |||
auto diff = tensor_iter_valonly<T>(src[0]).begin(); | |||
auto input = tensor_iter_valonly<T>(src[1]).begin(); | |||
auto scale = tensor_iter_valonly<T>(src[2]).begin(); | |||
auto zero_point = tensor_iter_valonly<T>(src[3]).begin(); | |||
auto grad_scale = tensor_iter_valonly<T>(src[4]).begin(); | |||
auto grad_x = tensor_iter_valonly<T>(src[5]).begin(); | |||
auto grad_s = tensor_iter_valonly<T>(src[6]).begin(); | |||
size_t total = src[0].layout.total_nr_elems(); | |||
for (size_t i = 0; i < total; ++i) { | |||
T x = (*input) / (*scale) + (*zero_point); | |||
bool ind_small = x < qmin; | |||
bool ind_big = x > qmax; | |||
bool ind_middle = ind_small ^ ind_big; | |||
ind_middle = !ind_middle; | |||
*grad_s = ind_small * qmin + ind_big * qmax + | |||
ind_middle * (-x + round(x)); | |||
*grad_s = (*grad_s) * (*grad_scale) * (*diff); | |||
*grad_x = ind_middle * (*diff); | |||
++diff; | |||
++input; | |||
++scale; | |||
++zero_point; | |||
++grad_scale; | |||
++grad_x; | |||
++grad_s; | |||
} | |||
} | |||
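// note: both impls walk every operand with tensor_iter_valonly, so
// non-contiguous layouts are handled by the iterator without a separate path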
} // namespace | |||
namespace megdnn { | |||
namespace naive { | |||
void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out output, | |||
_megdnn_workspace workspace) { | |||
check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout, | |||
output.layout, workspace.size); | |||
ElemwiseOpParamN<5> src; | |||
src[0] = input; | |||
src[1] = output; | |||
src[2] = scale; | |||
src[2].layout = src[2].layout.broadcast(input.layout); | |||
src[3] = zero_point; | |||
src[3].layout = src[3].layout.broadcast(input.layout); | |||
src[4] = grad_scale; | |||
src[4].layout = src[4].layout.broadcast(input.layout); | |||
#define cb(DType) \ | |||
if (input.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
forward_impl<T>(src, param().qmin, param().qmax)); \ | |||
return; \ | |||
} | |||
cb(dtype::Float32) | |||
#undef cb | |||
} | |||
void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, input.layout, scale.layout, zero_point.layout, | |||
grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size); | |||
ElemwiseOpParamN<7> src; | |||
src[0] = diff; | |||
src[1] = input; | |||
src[2] = scale; | |||
src[2].layout = src[2].layout.broadcast(input.layout); | |||
src[3] = zero_point; | |||
src[3].layout = src[3].layout.broadcast(input.layout); | |||
src[4] = grad_scale; | |||
src[4].layout = src[4].layout.broadcast(input.layout); | |||
src[5] = grad_x; | |||
src[6] = grad_s; | |||
#define cb(DType) \ | |||
if (diff.layout.dtype == DType() && grad_x.layout.dtype == DType() && \ | |||
input.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
backward_impl<T>(src, param().qmin, param().qmax)); \ | |||
return; \ | |||
} | |||
cb(dtype::Float32) | |||
#undef cb | |||
} | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,53 @@ | |||
/** | |||
* \file dnn/src/naive/lsq/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class LSQForwardImpl final : public LSQForward { | |||
public: | |||
using LSQForward::LSQForward; | |||
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
_megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, | |||
_megdnn_tensor_out output, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& /* input */, | |||
const TensorLayout& /* scale */, | |||
const TensorLayout& /* zero_point */, | |||
const TensorLayout& /* grad_scale */, | |||
const TensorLayout& /* output */) override { | |||
return 0; | |||
} | |||
}; | |||
class LSQBackwardImpl final : public LSQBackward { | |||
public: | |||
using LSQBackward::LSQBackward; | |||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
_megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& /* diff */, | |||
const TensorLayout& /* input */, | |||
const TensorLayout& /* scale */, | |||
const TensorLayout& /* zero_point */, | |||
const TensorLayout& /* grad_scale */, | |||
const TensorLayout& /* grad_x */, | |||
const TensorLayout& /* grad_s */) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,53 @@ | |||
/** | |||
* \file dnn/test/common/lsq.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/opr_param_defs.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace lsq { | |||
struct TestArg { | |||
param::LSQ param; | |||
TensorShape ishape; | |||
TensorShape scale_shape; | |||
TensorShape zeropoint_shape; | |||
TensorShape gradscale_shape; | |||
TestArg(param::LSQ param, TensorShape ishape, TensorShape scale_shape, | |||
TensorShape zeropoint_shape, TensorShape gradscale_shape) | |||
: param(param), | |||
ishape(ishape), | |||
scale_shape(scale_shape), | |||
zeropoint_shape(zeropoint_shape), | |||
gradscale_shape(gradscale_shape) {} | |||
}; | |||
inline std::vector<TestArg> get_args() { | |||
std::vector<TestArg> args; | |||
param::LSQ cur_param; | |||
cur_param.qmin = -127; | |||
cur_param.qmax = 127; | |||
for (size_t i = 10; i < 30; i += 2) { | |||
args.emplace_back(cur_param, TensorShape{10, 64, i, i}, TensorShape{1}, | |||
TensorShape{1}, TensorShape{1}); | |||
} | |||
return args; | |||
} | |||
} // namespace lsq | |||
} // namespace test | |||
} // namespace megdnn |
@@ -0,0 +1,110 @@ | |||
/** | |||
* \file dnn/test/cuda/lsq.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "test/common/lsq.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
#include "test/cuda/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
using namespace lsq; | |||
TEST_F(CUDA, LSQ) { | |||
std::vector<TestArg> args = get_args(); | |||
auto dtype = dtype::Float32(); | |||
for (auto&& arg : args) { | |||
auto param = arg.param; | |||
auto ishape = arg.ishape; | |||
auto scale_shape = arg.scale_shape; | |||
auto zeropoint_shape = arg.zeropoint_shape; | |||
auto gradscale_shape = arg.gradscale_shape; | |||
Checker<LSQForward> checker(handle_cuda()); | |||
checker.set_param(param) | |||
.set_dtype(0, dtype) | |||
.set_dtype(1, dtype) | |||
.set_dtype(2, dtype) | |||
.set_dtype(3, dtype) | |||
.set_dtype(4, dtype) | |||
.execs({ishape, scale_shape, zeropoint_shape, gradscale_shape, | |||
ishape}); | |||
} | |||
// test noncontiguous layout | |||
for (auto&& arg : args) { | |||
auto param = arg.param; | |||
auto ishape = arg.ishape; | |||
auto sshape = arg.scale_shape; | |||
auto zeropoint_shape = arg.zeropoint_shape; | |||
auto gradscale_shape = arg.gradscale_shape; | |||
Checker<LSQForward> checker(handle_cuda()); | |||
TensorLayout ilayout( | |||
ishape, | |||
{(long int)(ishape[1] * ishape[2] * ishape[3] * 2), | |||
(long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1}, | |||
dtype::Float32()); | |||
checker.set_param(param).execl({ilayout, | |||
{sshape, dtype::Float32()}, | |||
{zeropoint_shape, dtype::Float32()}, | |||
{gradscale_shape, dtype::Float32()}, | |||
ilayout}); | |||
} | |||
} | |||
TEST_F(CUDA, LSQ_BACKWARD) { | |||
std::vector<TestArg> args = get_args(); | |||
auto dtype = dtype::Float32(); | |||
for (auto&& arg : args) { | |||
auto param = arg.param; | |||
auto ishape = arg.ishape; | |||
auto scale_shape = arg.scale_shape; | |||
auto zeropoint_shape = arg.zeropoint_shape; | |||
auto gradscale_shape = arg.gradscale_shape; | |||
Checker<LSQBackward> checker(handle_cuda()); | |||
checker.set_param(param) | |||
.set_dtype(0, dtype) | |||
.set_dtype(1, dtype) | |||
.set_dtype(2, dtype) | |||
.set_dtype(3, dtype) | |||
.set_dtype(4, dtype) | |||
.set_dtype(5, dtype) | |||
.set_dtype(6, dtype) | |||
.execs({ishape, ishape, scale_shape, zeropoint_shape, | |||
gradscale_shape, ishape, ishape}); | |||
} | |||
// test noncontiguous layout | |||
for (auto&& arg : args) { | |||
auto param = arg.param; | |||
auto ishape = arg.ishape; | |||
auto sshape = arg.scale_shape; | |||
auto zeropoint_shape = arg.zeropoint_shape; | |||
auto gradscale_shape = arg.gradscale_shape; | |||
Checker<LSQBackward> checker(handle_cuda()); | |||
TensorLayout ilayout( | |||
ishape, | |||
{(long int)(ishape[1] * ishape[2] * ishape[3] * 2), | |||
(long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1}, | |||
dtype::Float32()); | |||
checker.set_param(param).execl({ilayout, | |||
ilayout, | |||
{sshape, dtype::Float32()}, | |||
{zeropoint_shape, dtype::Float32()}, | |||
{gradscale_shape, dtype::Float32()}, | |||
ilayout, | |||
ilayout}); | |||
} | |||
} | |||
} // namespace test | |||
} // namespace megdnn |
@@ -0,0 +1,45 @@ | |||
/** | |||
* \file dnn/test/naive/lsq.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "test/naive/fixture.h" | |||
#include "megdnn/oprs/nn.h" | |||
#include "test/common/checker.h" | |||
using namespace megdnn; | |||
using namespace test; | |||
TEST_F(NAIVE, LSQ_FORWARD) { | |||
Checker<LSQ> checker(handle(), /* check_dispatch */ false); | |||
param::LSQ param; | |||
param.qmin = -127; | |||
param.qmax = 127; | |||
TensorND input = | |||
TensorValue({2, 2, 2, 2}, dtype::Float32(), | |||
{0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 6, 7, 4, 5, 7, 8}); | |||
TensorND scale_shape = TensorValue({1}, dtype::Float32(), {2}); | |||
TensorND zero_point = TensorValue({1}, dtype::Float32(), {1}); | |||
TensorND grad_scale = TensorValue({1}, dtype::Float32(), {0.5}); | |||
TensorND output = | |||
TensorValue({2, 2, 2, 2}, dtype::Float32(), | |||
{0, 2, 4, 4, 2, 2, 4, 6, 4, 4, 6, 8, 4, 6, 8, 8}); | |||
checker.set_param(param).exect( | |||
Testcase{input, scale_shape, zero_point, grad_scale, {}}, | |||
Testcase{{}, {}, {}, {}, output}); | |||
} |
@@ -6,7 +6,7 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .fake_quant import TQT, FakeQuantize | |||
from .fake_quant import LSQ, TQT, FakeQuantize | |||
from .observer import ( | |||
ExponentialMovingAverageObserver, | |||
HistogramObserver, | |||
@@ -12,13 +12,15 @@ from .. import functional as F | |||
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | |||
from ..logger import get_logger | |||
from ..module import Module | |||
from ..tensor import Parameter | |||
from ..tensor import Parameter, Tensor | |||
from .utils import ( | |||
LSQParams, | |||
QParams, | |||
QParamsModuleMixin, | |||
QuantMode, | |||
create_qparams, | |||
fake_quant_tensor, | |||
lsq_forward, | |||
tqt_forward, | |||
) | |||
@@ -117,3 +119,58 @@ class FakeQuantize(_FakeQuantize): | |||
qparams.dtype_meta, self.dtype | |||
) | |||
return fake_quant_tensor(inp, qparams) | |||
class LSQ(_FakeQuantize, QParamsModuleMixin): | |||
r""" | |||
LSQ: Learned Step Size Quantization (https://arxiv.org/pdf/1902.08153.pdf). | |||
Estimates and scales the task-loss gradient at each weight and activation | |||
layer's quantizer step size. | |||
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target | |||
quantization dtype of input. | |||
:param enable: whether to do ``normal_forward`` or ``fake_quant_forward``. | |||
:param eps: a small value to avoid division by zero. Default: 1e-5 | |||
""" | |||
def __init__( | |||
self, | |||
dtype: Union[str, QuantDtypeMeta], | |||
enable: bool = True, | |||
eps: float = 1e-5, | |||
**kwargs | |||
): | |||
super().__init__(dtype=dtype, enable=enable, **kwargs) | |||
self.eps = Tensor(eps, dtype="float32") | |||
self.step_size = Parameter(1.0, dtype="float32") | |||
def set_qparams(self, qparams: LSQParams): | |||
self.mode = qparams.mode | |||
if qparams.mode == QuantMode.ASYMMERTIC: | |||
self.zero_point = qparams.zero_point | |||
else: | |||
self.zero_point = Tensor([0.0], dtype="float32") | |||
if qparams.scale is None: | |||
raise AssertionError("Cannot get an initialized scale") | |||
init_step_size = qparams.scale | |||
if init_step_size < self.eps: | |||
init_step_size = 0 | |||
else: | |||
init_step_size = init_step_size - self.eps | |||
self.step_size = Parameter(init_step_size, dtype="float32") | |||
self.grad_scale = qparams.grad_scale | |||
def fake_quant_forward(self, inp, qparams: LSQParams = None): | |||
step_size = F.abs(self.step_size) + self.eps | |||
return lsq_forward( | |||
self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale | |||
) | |||
def get_qparams(self): | |||
return LSQParams( | |||
mode=self.mode, | |||
dtype_meta=self.dtype, | |||
scale=F.abs(self.step_size.detach()) + self.eps, | |||
zero_point=self.zero_point, | |||
grad_scale=self.grad_scale, | |||
) |
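A hedged end-to-end sketch of the module: in practice an observer provides LSQParams through set_qparams before the first forward. The "qint8" dtype string and the QuantMode member spelling are taken from the surrounding codebase, not from this diff:

import numpy as np
import megengine as mge
from megengine.quantization.fake_quant import LSQ
from megengine.quantization.utils import LSQParams, QuantMode

lsq = LSQ(dtype="qint8")
lsq.set_qparams(LSQParams(
    mode=QuantMode.SYMMERTIC,  # member spelled as in this codebase
    dtype_meta=lsq.dtype,      # assumed to hold the QuantDtypeMeta
    scale=mge.tensor([0.25], dtype="float32"),
    zero_point=mge.tensor([0.0], dtype="float32"),
    grad_scale=mge.tensor([1.0], dtype="float32"),
))
y = lsq(mge.tensor(np.random.randn(4, 8).astype("float32")))  # fake-quantized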
@@ -43,6 +43,16 @@ def tqt_forward(qmin, qmax, inp, scale): | |||
return output | |||
def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None): | |||
if zero_point is None: | |||
zero_point = Tensor([0.0], dtype=np.float32) | |||
if scale_grad is None: | |||
scale_grad = Tensor([1.0], dtype=np.float32) | |||
op = builtin.LSQ(qmin=qmin, qmax=qmax) | |||
(output,) = apply(op, inp, step_size, zero_point, scale_grad) | |||
return output | |||
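A minimal usage sketch of lsq_forward (assuming this hunk lives in megengine.quantization.utils, as the imports elsewhere in the diff suggest):

import numpy as np
import megengine as mge
from megengine.quantization.utils import lsq_forward

x = mge.tensor(np.random.randn(2, 3, 4, 4).astype("float32"))
step_size = mge.tensor([0.25], dtype="float32")
# zero_point and scale_grad fall back to 0.0 and 1.0 when omitted
y = lsq_forward(-127, 127, x, step_size)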
def register_method_to_class(cls): | |||
def decorator(func): | |||
@wraps(func) | |||
@@ -105,6 +115,47 @@ class QParams: | |||
return "QParams({})".format(content) | |||
class LSQParams: | |||
""" | |||
To standardize LSQ's qparams format. If custom | |||
qparams are needed, inherit this class and add custom ``__slots__``. | |||
""" | |||
__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" | |||
def __init__( | |||
self, | |||
mode: QuantMode, | |||
dtype_meta: QuantDtypeMeta, | |||
scale: Tensor, | |||
zero_point: Tensor, | |||
grad_scale: Tensor, | |||
): | |||
self.mode = mode | |||
self.dtype_meta = dtype_meta | |||
self.scale = scale | |||
self.zero_point = zero_point | |||
self.grad_scale = grad_scale | |||
def update(self, lsqparams: "LSQParams"): | |||
for key in self.__slots__: | |||
setattr(self, key, getattr(lsqparams, key)) | |||
def __eq__(self, other): | |||
if len(self.__slots__) != len(other.__slots__): | |||
return False | |||
for key in self.__slots__: | |||
if not hasattr(other, key) or getattr(self, key) != getattr(other, key): | |||
return False | |||
return True | |||
def __repr__(self): | |||
content = ", ".join( | |||
["{}={}".format(key, getattr(self, key)) for key in self.__slots__] | |||
) | |||
return "LSQParams({})".format(content) | |||
class QParamsModuleMixin(abc.ABC): | |||
def get_quantized_dtype(self): | |||
qparams = self.get_qparams() | |||
@@ -10,6 +10,7 @@ import numpy as np | |||
import pytest | |||
import megengine as mge | |||
import megengine.functional as F | |||
from megengine import tensor | |||
from megengine.core.autodiff.grad import Function, Grad | |||
from megengine.core.tensor.dtype import QuantDtypeMeta | |||
@@ -19,6 +20,7 @@ from megengine.quantization.utils import ( | |||
QuantMode, | |||
create_qparams, | |||
fake_quant_tensor, | |||
lsq_forward, | |||
tqt_forward, | |||
) | |||
@@ -150,3 +152,78 @@ def test_fakequant(): | |||
zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | |||
scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | |||
run(zero_point, scale) | |||
class LSQ_numpy: | |||
def __init__(self, lowerbound, upperbound): | |||
super().__init__() | |||
self.lowerbound = lowerbound | |||
self.upperbound = upperbound | |||
def forward(self, inp, scale, zero_point, grad_scale): | |||
inp_scaled = inp / scale + zero_point | |||
inp_clipped = np.maximum( | |||
np.minimum(inp_scaled, self.upperbound), self.lowerbound | |||
) | |||
inp_rounded = np.floor(inp_clipped + 0.5) | |||
inp_flq = (inp_rounded - zero_point) * scale | |||
self.saved_tensors = (inp_scaled, inp_rounded, scale, grad_scale) | |||
return inp_flq | |||
def backward(self, grad_inp_flq): | |||
(inp_scaled, inp_rounded, scale, grad_scale) = self.saved_tensors | |||
ind_small = inp_scaled < self.lowerbound | |||
ind_big = inp_scaled > self.upperbound | |||
ind_middle = np.logical_xor(ind_small, ind_big) | |||
ind_middle = np.abs(ind_middle - 1) | |||
grad_s = ( | |||
ind_small * self.lowerbound | |||
+ ind_big * self.upperbound | |||
+ ind_middle * (-inp_scaled + inp_rounded) | |||
) | |||
grad_s = grad_s * grad_scale * grad_inp_flq | |||
grad_s = grad_s.sum() | |||
grad_inp = grad_inp_flq * ind_middle | |||
return grad_inp, grad_s | |||
def test_lsq(): | |||
def preprocess(scale, eps): | |||
scale = np.array([0]) if scale < eps else scale - eps | |||
return np.abs(scale) + eps | |||
g = [] | |||
def cb(grad): | |||
g.append(grad) | |||
x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") | |||
s = np.random.rand(1) | |||
eps = np.array([1e-5], dtype="float32") | |||
s = preprocess(s, eps) | |||
zero_point = np.array([1.0], dtype="float32") | |||
grad_s = np.array([2.0], dtype="float32") | |||
g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32") | |||
n = LSQ_numpy(-127, 127) | |||
y_np = n.forward(x, s, zero_point, grad_s) | |||
g_x_np, g_s_np = n.backward(g_y) | |||
x = mge.tensor(x, dtype="float32") | |||
s = mge.tensor(s, dtype="float32") | |||
zero_point = mge.tensor(zero_point, dtype="float32") | |||
grad_s = mge.tensor(grad_s, dtype="float32") | |||
g_y = mge.tensor(g_y, dtype="float32") | |||
grad = Grad().wrt(x, s, callback=cb) | |||
y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||
grad(y, g_y) | |||
g_x, g_s = g | |||
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) | |||
np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-7, atol=1e-7) | |||
np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-7, atol=5e-7) |
@@ -6,23 +6,26 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
// FIXME: split this file into separate files for each specialized op | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/correlation.h" | |||
#include "megbrain/opr/dnn/fake_quant.h" | |||
#include "megbrain/opr/dnn/tqt.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/images2neibs.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/dnn/lsq.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
#include "megbrain/opr/dnn/correlation.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/dnn/tqt.h" | |||
#include "megbrain/opr/imgproc.h" | |||
#include "megbrain/opr/indexing.h" | |||
#include "megbrain/opr/io.h" | |||
@@ -32,40 +35,38 @@ | |||
#include "megbrain/opr/tensor_gen.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/opr/dnn/images2neibs.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace { namespace dimshuffle { | |||
namespace { | |||
namespace dimshuffle { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
auto* node = &node_->cast_final_safe<opr::Dimshuffle>(); | |||
std::vector<int> pattern(node->param().pattern_len); | |||
for (size_t i = 0; i < node->param().pattern_len; ++ i) { | |||
for (size_t i = 0; i < node->param().pattern_len; ++i) { | |||
pattern[i] = node->param().pattern[i]; | |||
} | |||
return Dimshuffle::make(pattern); | |||
} | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& ds = static_cast<const Dimshuffle&>(def); | |||
OperatorNodeConfig config{ds.make_name()}; | |||
return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config); | |||
} | |||
OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // dimshuffle | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace dimshuffle | |||
} // namespace | |||
namespace { namespace add_axis { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
namespace { | |||
namespace add_axis { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& add_axis = static_cast<const AddAxis&>(def); | |||
using Desc = opr::AxisAddRemove::AxisDesc; | |||
std::vector<Desc> param; | |||
@@ -76,15 +77,13 @@ auto apply_on_var_node( | |||
return opr::AxisAddRemove::make(inputs[0], param, config); | |||
} | |||
OP_TRAIT_REG(AddAxis, AddAxis) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // add_axis | |||
OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace add_axis | |||
} // namespace | |||
namespace { namespace remove_axis { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
namespace { | |||
namespace remove_axis { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& remove_axis = static_cast<const RemoveAxis&>(def); | |||
using Desc = opr::AxisAddRemove::AxisDesc; | |||
std::vector<Desc> param; | |||
@@ -96,36 +95,35 @@ auto apply_on_var_node( | |||
} | |||
OP_TRAIT_REG(RemoveAxis, RemoveAxis) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // remove_axis | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace remove_axis | |||
} // namespace | |||
namespace { namespace top_k { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
namespace { | |||
namespace top_k { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& topk = static_cast<const TopK&>(def); | |||
OperatorNodeConfig config{topk.make_name()}; | |||
return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0] | |||
.node()->owner_opr(); | |||
.node() | |||
->owner_opr(); | |||
} | |||
OP_TRAIT_REG(TopK, TopK) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // top_k | |||
OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace top_k | |||
} // namespace | |||
namespace { namespace reduce { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
namespace { | |||
namespace reduce { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& reduce = static_cast<const Reduce&>(def); | |||
OperatorNodeConfig config{reduce.make_name()}; | |||
if (inputs.size() > 1) { | |||
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | |||
} else { | |||
return opr::Reduce::make( | |||
inputs[0], reduce.param(), (cg::VarNode*)nullptr, config); | |||
return opr::Reduce::make(inputs[0], reduce.param(), | |||
(cg::VarNode*)nullptr, config); | |||
} | |||
} | |||
@@ -135,86 +133,92 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
} | |||
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // reduce | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace reduce | |||
} // namespace | |||
namespace { namespace adaptive_pooling { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
namespace { | |||
namespace adaptive_pooling { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& pool = static_cast<const AdaptivePooling&>(def); | |||
OperatorNodeConfig config{pool.make_name()}; | |||
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config); | |||
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), | |||
config); | |||
} | |||
OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // adaptive_pooling | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace adaptive_pooling | |||
} // namespace | |||
namespace { namespace conv_bias { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
namespace { | |||
namespace conv_bias { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const ConvBias&>(def); | |||
cg::OperatorNodeConfig config{conv.dtype}; | |||
config.name(conv.make_name()); | |||
if (inputs.size() == 2) { | |||
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), | |||
conv.policy(), config); | |||
} else if (inputs.size() == 3) { | |||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], | |||
conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 4) { | |||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], | |||
conv.param(), conv.policy(), config); | |||
} | |||
mgb_assert(0); | |||
} | |||
OP_TRAIT_REG(ConvBias, ConvBias) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // conv_bias | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace conv_bias | |||
} // namespace | |||
namespace { namespace batch_conv_bias { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
namespace { | |||
namespace batch_conv_bias { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const BatchConvBias&>(def); | |||
cg::OperatorNodeConfig config{conv.dtype}; | |||
config.name(conv.make_name()); | |||
if (inputs.size() == 2) { | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), | |||
conv.policy(), config); | |||
} else if (inputs.size() == 3) { | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], | |||
conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 4) { | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], | |||
inputs[3], conv.param(), conv.policy(), | |||
config); | |||
} | |||
mgb_assert(0); | |||
} | |||
OP_TRAIT_REG(BatchConvBias, BatchConvBias) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace batch_conv_bias | |||
} // namespace | |||
namespace { | |||
namespace pooling { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& pool = static_cast<const Pooling&>(def); | |||
OperatorNodeConfig config{pool.make_name()}; | |||
return opr::Pooling::make(inputs[0], pool.param(), config); | |||
} | |||
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace pooling | |||
} // namespace | |||
namespace { | |||
namespace matrix_mul { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& matmul = static_cast<const MatrixMul&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{matmul.make_name()}; | |||
@@ -222,14 +226,14 @@ auto apply_on_var_node( | |||
matmul.policy(), config); | |||
} | |||
OP_TRAIT_REG(MatrixMul, MatrixMul) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace matrix_mul | |||
} // namespace | |||
namespace { | |||
namespace batched_matrix_mul { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{matmul.make_name()}; | |||
@@ -237,166 +241,155 @@ auto apply_on_var_node( | |||
matmul.policy(), config); | |||
} | |||
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace batched_matrix_mul | |||
} // namespace | |||
namespace { | |||
namespace dot { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = def.cast_final_safe<Dot>(); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Dot::make(inputs[0], inputs[1], config); | |||
} | |||
OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace dot | |||
} // namespace | |||
namespace { | |||
namespace argsort { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& argsort = static_cast<const Argsort&>(def); | |||
OperatorNodeConfig config{argsort.make_name()}; | |||
return opr::Argsort::make(inputs[0], argsort.param(), config); | |||
} | |||
OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace argsort | |||
} // namespace | |||
namespace { | |||
namespace argmax { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& argmax = static_cast<const Argmax&>(def); | |||
OperatorNodeConfig config{argmax.make_name()}; | |||
return opr::Argmax::make(inputs[0], argmax.param(), config); | |||
} | |||
OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace argmax | |||
} // namespace | |||
namespace { | |||
namespace argmin { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& argmin = static_cast<const Argmin&>(def); | |||
OperatorNodeConfig config{argmin.make_name()}; | |||
return opr::Argmin::make(inputs[0], argmin.param(), config); | |||
} | |||
OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace argmin | |||
} // namespace | |||
namespace { | |||
namespace warp_perspective { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& warp = static_cast<const WarpPerspective&>(def); | |||
OperatorNodeConfig config{warp.make_name()}; | |||
if (inputs.size() == 3) { | |||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], | |||
warp.param(), config); | |||
} else { | |||
mgb_assert(inputs.size() == 4); | |||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], | |||
inputs[3], warp.param(), config); | |||
} | |||
} | |||
OP_TRAIT_REG(WarpPerspective, WarpPerspective) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace warp_perspective | |||
} // namespace | |||
namespace { | |||
namespace group_local { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& local = static_cast<const GroupLocal&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{local.make_name()}; | |||
return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); | |||
} | |||
OP_TRAIT_REG(GroupLocal, GroupLocal) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace group_local | |||
} // namespace | |||
namespace { | |||
namespace indexing_one_hot { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const IndexingOneHot&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace indexing_one_hot | |||
} // namespace | |||
namespace { | |||
namespace indexing_set_one_hot { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const IndexingSetOneHot&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], | |||
op.param(), config); | |||
} | |||
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace indexing_set_one_hot | |||
} // namespace | |||
namespace { | |||
namespace typecvt { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const TypeCvt&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::TypeCvt::make(inputs[0], op.dtype, config); | |||
} | |||
OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace typecvt | |||
} // namespace | |||
namespace { | |||
namespace concat { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Concat&>(def); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
config.name(op.make_name()); | |||
return opr::Concat::make(inputs, op.axis, config); | |||
} | |||
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace concat | |||
} // namespace | |||
namespace { | |||
namespace copy { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Copy&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
config.name(op.make_name()); | |||
return opr::Copy::make(inputs[0], config); | |||
} | |||
OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace copy | |||
} // namespace | |||
namespace { namespace assert_equal { | |||
auto apply_on_var_node( | |||
@@ -408,81 +401,81 @@ auto apply_on_var_node( | |||
} else { | |||
// workaround for MiniGraph, which only allow one opr in the graph | |||
mgb_assert(inputs.size() == 3); | |||
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], | |||
op.param(), {}); | |||
} | |||
} | |||
OP_TRAIT_REG(AssertEqual, AssertEqual) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace assert_equal | |||
} // namespace | |||
namespace { | |||
namespace roi_align { | |||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ROIAlign&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{op.make_name()}; | |||
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config) | |||
.node() | |||
->owner_opr(); | |||
return {opr->output(0), opr->output(1)}; | |||
} | |||
OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace roi_align | |||
} // namespace | |||
namespace { | |||
namespace correlation { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Correlation&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Correlation::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Correlation, Correlation) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace correlation | |||
} // namespace | |||
#if MGB_CUDA | |||
namespace { | |||
namespace nvof { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const NvOf&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::NvOf::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace nvof | |||
} // namespace | |||
#endif | |||
namespace { | |||
namespace linspace { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Linspace&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
config.name(op.make_name()); | |||
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), | |||
config); | |||
} | |||
OP_TRAIT_REG(Linspace, Linspace) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace linspace | |||
} // namespace | |||
namespace { | |||
namespace eye { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Eye&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
@@ -490,58 +483,59 @@ auto apply_on_var_node( | |||
opr::Eye::Param param{op.k, op.dtype.enumv()}; | |||
return opr::Eye::make(inputs[0], param, config); | |||
} | |||
OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace eye | |||
} // namespace | |||
namespace { | |||
namespace roi_pooling { | |||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ROIPooling&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
OperatorNodeConfig config{op.make_name()}; | |||
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], | |||
op.param(), config) | |||
.node() | |||
->owner_opr(); | |||
return {opr->output(0), opr->output(1)}; | |||
} | |||
OP_TRAIT_REG(ROIPooling, ROIPooling) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace roi_pooling | |||
} // namespace | |||
namespace { | |||
namespace remap { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Remap&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Remap::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace remap | |||
} // namespace | |||
namespace { | |||
auto get_index( | |||
const VarNodeArray& inputs, size_t vidx, | |||
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) { | |||
size_t length = mask.size(); | |||
opr::Subtensor::IndexDesc ret(length); | |||
for (size_t i = 0; i < length; ++i) { | |||
auto&& [axis, begin, end, step, idx] = mask[i]; | |||
ret[i].axis = axis; | |||
if (idx) { | |||
ret[i].idx = inputs[vidx++]; | |||
} else { | |||
mgb_assert(begin || end || step); | |||
if (begin) | |||
ret[i].begin = inputs[vidx++]; | |||
if (end) | |||
ret[i].end = inputs[vidx++]; | |||
if (step) | |||
ret[i].step = inputs[vidx++]; | |||
} | |||
} | |||
mgb_assert(vidx == inputs.size()); | |||
@@ -550,19 +544,19 @@ auto get_index( | |||
#define IN1 inputs[0] | |||
#define IN2 inputs[0], inputs[1] | |||
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ | |||
namespace NAME##_impl { \ | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \ | |||
auto&& op = static_cast<const NAME&>(def); \ | |||
OperatorNodeConfig config{op.make_name()}; \ | |||
return opr::NAME::make(IN##NR_INPUT, \ | |||
get_index(inputs, NR_INPUT, op.items), \ | |||
config); \ | |||
} \ | |||
OP_TRAIT_REG(NAME, NAME) \ | |||
.apply_on_var_node(apply_on_var_node) \ | |||
.fallback(); \ | |||
} | |||
FANCY_INDEXING_IMPL(Subtensor, 1) | |||
FANCY_INDEXING_IMPL(SetSubtensor, 2) | |||
@@ -580,76 +574,88 @@ FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2) | |||
#undef FANCY_INDEXING_IMPL | |||
#undef IN1 | |||
#undef IN2 | |||
} // anonymous namespace | |||
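Each tuple in the mask argument of get_index above is (axis, begin, end, step, idx): idx marks a single-index access, while the begin/end/step flags say which slice endpoints consume the next entries of inputs, in that order. As a hypothetical illustration (the values and the scaffolding below are assumptions for exposition, not code from this change), a mask describing the slice x[:, 2:5] could be built like this:

#include <cstdint>
#include <tuple>
#include <vector>

int main() {
    // axis = 1, has begin, has end, no step, no single-index -- i.e. the
    // slice x[:, 2:5]. get_index would then consume two extra graph inputs
    // (holding 2 and 5) right after the NR_INPUT data tensors.
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> mask = {
            {1, true, true, false, false},
    };
    (void)mask;
    return 0;
}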
namespace { | |||
namespace fake_quant { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const FakeQuant&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), | |||
config); | |||
} | |||
OP_TRAIT_REG(FakeQuant, FakeQuant) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace fake_quant | |||
} // namespace | |||
namespace { | |||
namespace tqt { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const TQT&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::TQT::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace tqt | |||
} // namespace | |||
namespace { | |||
namespace elemwise_multi_type { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ElemwiseMultiType&>(def); | |||
OperatorNodeConfig config{op.dtype}; | |||
config.name(op.make_name()); | |||
return opr::ElemwiseMultiType::make(inputs, op.param(), config); | |||
} | |||
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace elemwise_multi_type | |||
} // namespace | |||
namespace { | |||
namespace svd { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const SVD&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::SVD::make(inputs[0], op.param(), config)[0] | |||
.node() | |||
->owner_opr() | |||
->usable_output(); | |||
} | |||
OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace svd | |||
} // namespace | |||
namespace { | |||
namespace images2neibs { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Images2Neibs&>(def); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Images2Neibs::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Images2Neibs, Images2Neibs) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace images2neibs | |||
} // namespace | |||
namespace { | |||
namespace lsq { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const LSQ&>(def); | |||
mgb_assert(inputs.size() == 4); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3], | |||
op.param(), config); | |||
} | |||
OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace lsq | |||
} // namespace | |||
} // namespace mgb::imperative |
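For context on the newly registered op: LSQ (Learned Step Size Quantization) is a fake-quantize with a learnable step size. A minimal scalar sketch of the forward pass, assuming the standard formulation from the LSQ paper rather than quoting the megdnn kernel:

#include <algorithm>
#include <cmath>

// Sketch only: q = clamp(round(x / s + zp), qmin, qmax); y = (q - zp) * s.
// grad_scale does not change the forward output; it only scales the
// step-size gradient in the backward pass.
float lsq_forward(float x, float s, float zp, int qmin, int qmax) {
    float q = std::round(x / s + zp);
    q = std::min(std::max(q, static_cast<float>(qmin)),
                 static_cast<float>(qmax));
    return (q - zp) * s;
}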
@@ -6,22 +6,24 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./helper.h" | |||
#include "megbrain/imperative/backward_graph_opt.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
using namespace mgb; | |||
using namespace cg; | |||
using namespace imperative; | |||
template <typename T> | |||
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, | |||
const T& outputs, const T& grads) { | |||
T ret; | |||
size_t i = 0; | |||
for (auto&& t : inputs) { | |||
@@ -54,7 +56,9 @@ T expand_grads(const U& bg, const T& outputs) { | |||
} | |||
template <typename T> | |||
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, | |||
const T& precomp, const T& inputs, | |||
const T& outputs, const T& grads) { | |||
T ret = precomp; | |||
size_t i = 0; | |||
for (auto&& t : inputs) { | |||
@@ -75,7 +79,8 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons | |||
return ret; | |||
} | |||
SmallVector<TensorPtr> apply_shared_on_physical_tensor( | |||
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||
return OpDef::apply_on_physical_tensor(*def, inputs); | |||
} | |||
@@ -83,7 +88,7 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
HostTensorGenerator<> gen; | |||
SmallVector<HostTensorND> hvs; | |||
SmallVector<TensorPtr> inputs; | |||
for (size_t i = 0; i < 2; ++i) { | |||
hvs.push_back(*gen({42})); | |||
inputs.push_back(Tensor::make(hvs.back())); | |||
} | |||
@@ -97,7 +102,8 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
for (auto&& i : inputs) { | |||
input_descs.push_back({i->layout(), i->comp_node()}); | |||
} | |||
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, | |||
{true}); | |||
auto&& save_for_backward = result.save_for_backward; | |||
auto&& input_has_grad = result.input_has_grad; | |||
@@ -106,9 +112,9 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
hvs.push_back(*gen({42})); | |||
inputs.push_back(Tensor::make(hvs.back())); | |||
mgb_assert(save_for_backward.size() == inputs.size()); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
if (!save_for_backward[i]) { | |||
inputs[i].reset(); // drop unused tensor | |||
} | |||
} | |||
SmallVector<TensorPtr> backward_graph_inputs; | |||
@@ -118,13 +124,11 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
} | |||
} | |||
inputs.clear(); | |||
auto input_grads = result.backward.apply(backward_graph_inputs, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x) { return x; }); | |||
mgb_assert(input_grads.size() == input_has_grad.size()); | |||
for (size_t i = 0; i < input_has_grad.size(); ++i) { | |||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | |||
} | |||
@@ -133,9 +137,10 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
res.emplace_back(); | |||
res.back().copy_from(i->dev_tensor()).sync(); | |||
} | |||
for (size_t i = 0; i < 42; ++i) { | |||
for (size_t j = 0; j < 1; ++j) { | |||
ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i], | |||
res[j ^ 1].ptr<float>()[i]); | |||
} | |||
} | |||
} | |||
@@ -152,7 +157,8 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
input_descs.push_back({a->layout(), a->comp_node()}); | |||
auto result = | |||
OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||
auto&& save_for_backward = result.save_for_backward; | |||
auto&& input_has_grad = result.input_has_grad; | |||
@@ -160,9 +166,9 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
inputs.push_back(outputs[0]); | |||
inputs.push_back(dc); | |||
mgb_assert(save_for_backward.size() == inputs.size()); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
if (!save_for_backward[i]) { | |||
inputs[i].reset(); // drop unused tensor | |||
} | |||
} | |||
SmallVector<TensorPtr> backward_graph_inputs; | |||
@@ -172,19 +178,17 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
} | |||
} | |||
inputs.clear(); | |||
auto input_grads = result.backward.apply(backward_graph_inputs, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x) { return x; }); | |||
mgb_assert(input_grads.size() == input_has_grad.size()); | |||
for (size_t i = 0; i < input_has_grad.size(); ++i) { | |||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | |||
} | |||
HostTensorND hv; | |||
hv.copy_from(input_grads[0]->dev_tensor()).sync(); | |||
for (size_t i = 0; i < 42; ++i) { | |||
ASSERT_EQ(host_dc->ptr<float>()[i], hv.ptr<float>()[i]); | |||
} | |||
} | |||
@@ -192,7 +196,7 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
TEST(TestImperative, BatchNormGrad) { | |||
auto cn = CompNode::load("xpux"); | |||
using Param = opr::BatchNorm::Param; | |||
size_t N = 2, C = 3, H = 5, W = 5; | |||
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||
{ | |||
@@ -202,7 +206,8 @@ TEST(TestImperative, BatchNormGrad) { | |||
param.fwd_mode = Param::FwdMode::TRAINING; | |||
attr.param.write_pod(param); | |||
OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | |||
{true, true, true, false, false}, | |||
{false, false, false, false, true}); | |||
} | |||
{ | |||
auto op = OprAttr::make("BatchNorm"); | |||
@@ -210,8 +215,8 @@ TEST(TestImperative, BatchNormGrad) { | |||
Param param; | |||
param.fwd_mode = Param::FwdMode::TRAINING; | |||
attr.param.write_pod(param); | |||
OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true}, | |||
{false, false, true}); | |||
} | |||
} | |||
@@ -220,7 +225,8 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn}; | |||
HostTensorGenerator<> gen; | |||
auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD)); | |||
auto bg = | |||
OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); | |||
auto obg = OptimizedBackwardGraphResult(bg); | |||
ASSERT_EQ(obg.save_for_backward.size(), 4); | |||
ASSERT_FALSE(obg.save_for_backward[0]); | |||
@@ -235,30 +241,30 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
auto dc_tn = Tensor::make(*dc_hv); | |||
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | |||
auto backward_graph_inputs = | |||
prepare_backward_graph_inputs<SmallVector<TensorPtr>>( | |||
bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
auto grads = | |||
expand_grads(bg, bg.backward.apply(backward_graph_inputs, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x) { return x; })); | |||
auto precomp = obg.precomp.apply(SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x) { return x; }); | |||
ASSERT_EQ(precomp.size(), 2); | |||
ASSERT_EQ(precomp[0]->shape().ndim, 1); | |||
ASSERT_LE(precomp[0]->shape()[0], 2); | |||
ASSERT_EQ(precomp[1]->shape().ndim, 1); | |||
ASSERT_LE(precomp[1]->shape()[0], 2); | |||
auto backward_inputs = | |||
prepare_optimized_backward_inputs<SmallVector<TensorPtr>>( | |||
obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
auto grads2 = expand_grads( | |||
obg, | |||
obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor, | |||
[&](auto&& x) { return x; })); | |||
ASSERT_EQ(grads2.size(), 2); | |||
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | |||
@@ -271,6 +271,7 @@ def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; | |||
def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; | |||
def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>; | |||
def TQT: MgbHashableOp<"TQT", [TQTParam]>; | |||
def LSQ: MgbHashableOp<"LSQ", [LSQParam]>; | |||
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { | |||
let extraArguments = (ins | |||
MgbDTypeAttr:$dtype | |||
@@ -324,5 +324,7 @@ decl_opr('FakeQuant', | |||
decl_opr('TQT', | |||
inputs=[Doc('src','input tensor'),Doc('scale','scale tensor')], | |||
params='TQT') | |||
decl_opr('LSQ', | |||
inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor'),Doc('grad_scale','grad scale tensor')], | |||
params='LSQ') | |||
# vim: ft=python |
@@ -6,20 +6,22 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/correlation.h" | |||
#include "megbrain/opr/dnn/fake_quant.h" | |||
#include "megbrain/opr/dnn/images2neibs.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/dnn/lrn.h" | |||
#include "megbrain/opr/dnn/lsq.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "megbrain/opr/dnn/tqt.h" | |||
#include "megbrain/serialization/sereg.h" | |||
#include "megdnn/opr_param_defs.h" | |||
@@ -183,7 +185,8 @@ struct ConvLoadDumpImpl { | |||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
auto&& opr = opr_.cast_final_safe<Opr>(); | |||
ctx.write_param<ConvParam>(opr.param()); | |||
ctx.write_param<megdnn::param::ExecutionPolicy>( | |||
opr.execution_policy_transient()); | |||
} | |||
static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param, | |||
@@ -252,6 +255,20 @@ struct OprMaker<opr::TQTBackward, 3> { | |||
}; | |||
template <> | |||
struct OprMaker<opr::LSQBackward, 5> { | |||
using Param = opr::LSQBackward::Param; | |||
static cg::OperatorNodeBase* make(const Param& param, | |||
const cg::VarNodeArray& i, | |||
ComputingGraph& graph, | |||
const OperatorNodeConfig& config) { | |||
MGB_MARK_USED_VAR(graph); | |||
return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param, | |||
config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
}; | |||
template <> | |||
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | |||
: public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward, | |||
MakeAdaptivePoolingBackwardCaller3< | |||
@@ -587,6 +604,8 @@ MGB_SEREG_OPR(FakeQuant, 3); | |||
MGB_SEREG_OPR(FakeQuantBackward, 4); | |||
MGB_SEREG_OPR(TQT, 2); | |||
MGB_SEREG_OPR(TQTBackward, 3); | |||
MGB_SEREG_OPR(LSQ, 4); | |||
MGB_SEREG_OPR(LSQBackward, 5); | |||
} // namespace opr | |||
@@ -0,0 +1,90 @@ | |||
/** | |||
* \file src/opr/impl/dnn/lsq.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/opr/dnn/lsq.h" | |||
#include "../internal/megdnn_opr_wrapper.inl" | |||
#include "megbrain/graph/grad_impl.h" | |||
#include "megbrain/opr/basic_arith_wrapper.h" | |||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
using namespace mgb; | |||
using namespace opr; | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQForward); | |||
MEGDNN_OPR_INIT4(LSQForward, "lsq_fwd"); | |||
#ifdef MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(LSQForward) { | |||
SymbolVarArray grad = | |||
LSQBackward::make(out_grad[0], opr.input(0), opr.input(1), | |||
opr.input(2), opr.input(3), opr.param()); | |||
if (wrt_idx == 0) { | |||
return grad[0].node(); | |||
} else if (wrt_idx == 1) { | |||
return reduce_sum(grad[1], GetVarShape::make(opr.input(wrt_idx))) | |||
.node(); | |||
} else { | |||
return nullptr; | |||
} | |||
} | |||
#endif | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQBackward); | |||
LSQBackward::LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, | |||
VarNode* zero_point, VarNode* grad_scale, | |||
const Param& param, const OperatorNodeConfig& config) | |||
: Super({x->owner_graph(), | |||
config, | |||
"lsq_bwd", | |||
{y_grad, x, scale, zero_point, grad_scale}}, | |||
1, true) { | |||
init_megdnn_opr(*this, param); | |||
add_input({y_grad, x, scale, zero_point, grad_scale}); | |||
} | |||
SymbolVarArray LSQBackward::make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, | |||
SymbolVar zero_point, SymbolVar grad_scale, | |||
const Param& param, | |||
const OperatorNodeConfig& config) { | |||
auto&& out = x.node()->owner_graph() | |||
->insert_opr(std::make_unique<LSQBackward>( | |||
y_grad.node(), x.node(), scale.node(), | |||
zero_point.node(), grad_scale.node(), param, | |||
config)) | |||
->output(); | |||
SymbolVarArray ret(out.size()); | |||
for (size_t i = 0; i < ret.size(); ++i) { | |||
ret[i] = out[i]; | |||
} | |||
return ret; | |||
} | |||
void LSQBackward::init_output_static_infer_desc() { | |||
using namespace cg::static_infer; | |||
auto&& mgr = owner_graph()->static_infer_manager(); | |||
mgr.register_shape_infer(output(0), | |||
ShapeInferDesc::make_identity(input(1))); | |||
mgr.register_shape_infer(output(1), | |||
ShapeInferDesc::make_identity(input(1))); | |||
this->init_output_static_infer_desc_workspace( | |||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSQBackward>::val); | |||
} | |||
void LSQBackward::init_output_dtype() { | |||
output(0)->dtype(input(1)->dtype()); | |||
output(1)->dtype(input(2)->dtype()); | |||
} |
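The two outputs LSQBackward produces (the gradient toward the input and the gradient toward the step size) follow the LSQ paper: the input gradient passes straight through inside the representable range and is zero outside, while the step-size gradient is round(v) - v inside the range and the clamped endpoint at the boundaries, all scaled by grad_scale. A scalar sketch under those assumptions (not the megdnn kernel itself):

#include <cmath>

// v = x / s + zp is the pre-rounding quantized value; grad_scale is the
// factor LSQ applies to the step-size gradient (1 / sqrt(qmax * n) in the
// paper).
void lsq_backward(float dy, float x, float s, float zp, float grad_scale,
                  int qmin, int qmax, float* dx, float* ds) {
    float v = x / s + zp;
    if (v < qmin) {
        *dx = 0.f;
        *ds = (qmin - zp) * dy * grad_scale;
    } else if (v > qmax) {
        *dx = 0.f;
        *ds = (qmax - zp) * dy * grad_scale;
    } else {
        *dx = dy;  // straight-through estimator inside the range
        *ds = (std::round(v) - v) * dy * grad_scale;
    }
}

This matches the MGB_IMPL_OPR_GRAD block above, which routes grad[0] to the input and reduce-sums grad[1] to the scale's shape.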
@@ -0,0 +1,50 @@ | |||
/** | |||
* \file src/opr/include/megbrain/opr/dnn/lsq.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megdnn/oprs.h" | |||
namespace mgb { | |||
namespace opr { | |||
MGB_DEFINE_OPR_CLASS(LSQForward, | |||
intl::MegDNNOprWrapperFwd<megdnn::LSQForward>) // { | |||
public: | |||
LSQForward(VarNode* src, VarNode* scale, VarNode* zero_point, | |||
VarNode* grad_scale, const Param& param, | |||
const OperatorNodeConfig& config); | |||
static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point, | |||
SymbolVar grad_scale, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
}; | |||
using LSQ = LSQForward; | |||
MGB_DEFINE_OPR_CLASS(LSQBackward, | |||
intl::MegDNNOprWrapperBwd<megdnn::LSQBackward>) // { | |||
public: | |||
LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, VarNode* zero_point, | |||
VarNode* grad_scale, const Param& param, | |||
const OperatorNodeConfig& config); | |||
static SymbolVarArray make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, | |||
SymbolVar zero_point, SymbolVar grad_scale, | |||
const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
private: | |||
void init_output_static_infer_desc() override; | |||
void init_output_dtype() override; | |||
}; | |||
} // namespace opr | |||
} // namespace mgb |
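As a usage sketch, the forward opr declared in this header would be wired into a graph roughly as follows; the helper name and the int8-style range are illustrative assumptions, not part of this change:

#include "megbrain/opr/dnn/lsq.h"

namespace {
// Hypothetical helper: fake-quantize x with a learnable step size using an
// int8-style quantized range.
mgb::SymbolVar build_lsq(mgb::SymbolVar x, mgb::SymbolVar scale,
                         mgb::SymbolVar zero_point, mgb::SymbolVar grad_scale) {
    mgb::opr::LSQ::Param param;
    param.qmin = -128;
    param.qmax = 127;
    return mgb::opr::LSQ::make(x, scale, zero_point, grad_scale, param);
}
}  // anonymous namespace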
@@ -0,0 +1,78 @@ | |||
/** | |||
* \file src/opr/test/dnn/lsq.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/opr/dnn/lsq.h" | |||
#include "megbrain/comp_node_env.h" | |||
#include "megbrain/test/autocheck.h" | |||
using namespace std; | |||
using namespace mgb; | |||
namespace { | |||
void run() { | |||
using Checker = AutoOprChecker<4, 1>; | |||
auto make_graph = | |||
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
auto o0 = opr::LSQForward::make(inputs[0], inputs[1], inputs[2], | |||
inputs[3]); | |||
return {o0}; | |||
}; | |||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||
auto opr = MegDNNHandle::get( | |||
CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||
->create_operator<megdnn::LSQForward>(); | |||
dest[0].dtype(dtype::Float32()) | |||
.comp_node(inp[0]->comp_node()) | |||
.resize(inp[0]->shape()); | |||
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||
inp[3]->as_megdnn(), dest[0].as_megdnn(), {}); | |||
}; | |||
auto gen = [&](HostTensorND& src) { | |||
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> | |||
src_gen(10.f); | |||
src = *src_gen(src.shape(), src.comp_node()); | |||
}; | |||
Checker::RunOptions opt; | |||
opt.numdiff_max_err = 1e-5; | |||
Checker checker{make_graph, fwd}; | |||
checker.set_input_generator(0, gen) | |||
.set_input_generator(1, gen) | |||
.set_input_generator(2, gen) | |||
.set_input_generator(3, gen) | |||
.set_input_allow_grad(0, false) | |||
.set_input_allow_grad(1, false) | |||
.set_input_allow_grad(2, false) | |||
.set_input_allow_grad(3, false) | |||
.set_output_allow_grad(0, false); | |||
checker.run({TensorShape{1, 2, 3, 4}, TensorShape{1}, TensorShape{1}, | |||
TensorShape{1}}, | |||
opt) | |||
.run({TensorShape{2, 3, 8, 8}, TensorShape{1}, TensorShape{1}, | |||
TensorShape{1}}, | |||
opt) | |||
.run({TensorShape{1, 3, 4, 4}, TensorShape{1}, TensorShape{1}, | |||
TensorShape{1}}, | |||
opt); | |||
} | |||
} // anonymous namespace | |||
TEST(TestOprDNN, LSQForward) { | |||
REQUIRE_GPU(1); | |||
run(); | |||
} |
@@ -107,6 +107,7 @@ union OperatorParam { | |||
param.FakeQuant = 73, | |||
param.TQT = 74, | |||
param.Correlation = 75, | |||
param.LSQ = 76, | |||
} | |||
table Operator { | |||