@@ -1741,6 +1741,67 @@ protected: | |||||
const TensorLayout& grad_s, size_t workspace_in_bytes); | const TensorLayout& grad_s, size_t workspace_in_bytes); | ||||
}; | }; | ||||
class LSQBase : public OperatorBase { | |||||
DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase); | |||||
DEF_OPR_PARAM(LSQ); | |||||
protected: | |||||
void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output); | |||||
void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale, | |||||
const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, | |||||
const TensorLayout& output); | |||||
}; | |||||
class LSQForward : public LSQBase { | |||||
DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1); | |||||
public: | |||||
virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||||
_megdnn_tensor_in zero_point, | |||||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out output, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& input, const TensorLayout& scale, | |||||
const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, TensorLayout& output); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& input, | |||||
const TensorLayout& scale, | |||||
const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, | |||||
const TensorLayout& output) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout& input, const TensorLayout& scale, | |||||
const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, const TensorLayout& output, | |||||
size_t workspace_in_bytes); | |||||
}; | |||||
using LSQ = LSQForward; | |||||
class LSQBackward : public LSQBase { | |||||
DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2); | |||||
public: | |||||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||||
_megdnn_tensor_out grad_s, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||||
const TensorLayout& input, | |||||
const TensorLayout& scale, | |||||
const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, | |||||
const TensorLayout& grad_x, | |||||
const TensorLayout& grad_s) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout& diff, const TensorLayout& input, | |||||
const TensorLayout& scale, const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, const TensorLayout& grad_x, | |||||
const TensorLayout& grad_s, size_t workspace_in_bytes); | |||||
}; | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#include "megdnn/internal/opr_header_epilogue.h" | #include "megdnn/internal/opr_header_epilogue.h" | ||||
@@ -1124,3 +1124,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||||
add_fields('int32', 'qmin', '-2147483648'). | add_fields('int32', 'qmin', '-2147483648'). | ||||
add_fields('int32', 'qmax', '2147483647') | add_fields('int32', 'qmax', '2147483647') | ||||
) | ) | ||||
# Parameter definition for the LSQ operator: two int32 fields, qmin and qmax.
# The third add_fields argument is presumably the default value; the values
# given here are the int32 minimum and maximum.
(pdef('LSQ'). | |||||
add_fields('int32', 'qmin', '-2147483648'). | |||||
add_fields('int32', 'qmax', '2147483647') | |||||
) | |||||
@@ -37,6 +37,7 @@ namespace megdnn { | |||||
megdnn_assert(size, "uninitialized ElemwiseOpParamN"); | megdnn_assert(size, "uninitialized ElemwiseOpParamN"); | ||||
} | } | ||||
// explicit instantiation for seven operands
template struct ElemwiseOpParamN<7>; | |||||
template struct ElemwiseOpParamN<6>; | template struct ElemwiseOpParamN<6>; | ||||
template struct ElemwiseOpParamN<5>; | template struct ElemwiseOpParamN<5>; | ||||
template struct ElemwiseOpParamN<4>; | template struct ElemwiseOpParamN<4>; | ||||
@@ -208,7 +208,9 @@ private: | |||||
cb(FakeQuantBackward) \ | cb(FakeQuantBackward) \ | ||||
cb(TQTForward) \ | cb(TQTForward) \ | ||||
cb(TQTBackward) \ | cb(TQTBackward) \ | ||||
cb(CheckHasInf) | |||||
cb(CheckHasInf) \ | |||||
cb(LSQForward) \ | |||||
cb(LSQBackward) | |||||
/*! | /*! | ||||
* \brief specialize HandleImpl::create_operator for a single opr type; | * \brief specialize HandleImpl::create_operator for a single opr type; | ||||
@@ -0,0 +1,69 @@ | |||||
/** | |||||
* \file dnn/src/common/lsq.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
void LSQBase::deduce_layout_fwd(const TensorLayout& input, | |||||
TensorLayout& output) { | |||||
output = TensorLayout(input, input.dtype); | |||||
} | |||||
void LSQBase::check_layout_fwd(const TensorLayout& input, | |||||
const TensorLayout& scale, | |||||
const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, | |||||
const TensorLayout& output) { | |||||
megdnn_assert(input.dtype == dtype::Float32()); | |||||
megdnn_assert(scale.dtype == dtype::Float32()); | |||||
megdnn_assert(zero_point.dtype == dtype::Float32()); | |||||
megdnn_assert(grad_scale.dtype == dtype::Float32()); | |||||
TensorLayout expected; | |||||
deduce_layout_fwd(input, expected); | |||||
megdnn_assert_eq_layout(expected, output); | |||||
} | |||||
void LSQForward::deduce_layout(const TensorLayout& input, | |||||
const TensorLayout& /* scale */, | |||||
const TensorLayout& /*zero_point*/, | |||||
const TensorLayout& /*grad_scale*/, | |||||
TensorLayout& output) { | |||||
deduce_layout_fwd(input, output); | |||||
} | |||||
void LSQForward::check_exec(const TensorLayout& input, | |||||
const TensorLayout& scale, | |||||
const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, | |||||
const TensorLayout& output, | |||||
size_t workspace_in_bytes) { | |||||
check_layout_fwd(input, scale, zero_point, grad_scale, output); | |||||
auto required_workspace_space = get_workspace_in_bytes( | |||||
input, scale, zero_point, grad_scale, output); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_space); | |||||
} | |||||
void LSQBackward::check_exec( | |||||
const TensorLayout& diff, const TensorLayout& input, | |||||
const TensorLayout& scale, const TensorLayout& zero_point, | |||||
const TensorLayout& grad_scale, const TensorLayout& grad_x, | |||||
const TensorLayout& grad_s, size_t workspace_in_bytes) { | |||||
megdnn_assert_eq_shape(diff, input); | |||||
megdnn_assert_eq_shape(grad_x, input); | |||||
auto required_worspace_space = get_workspace_in_bytes( | |||||
diff, input, scale, zero_point, grad_scale, grad_x, grad_s); | |||||
megdnn_assert(workspace_in_bytes >= required_worspace_space); | |||||
} | |||||
} // namespace megdnn |
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
@@ -121,6 +122,8 @@ DEF(UniformRNG, 1, true, true); | |||||
DEF(GaussianRNG, 1, true, true); | DEF(GaussianRNG, 1, true, true); | ||||
DEF(ChecksumForward, 1, true, false); | DEF(ChecksumForward, 1, true, false); | ||||
DEF(CheckHasInf, 2, true, true); | DEF(CheckHasInf, 2, true, true); | ||||
// LSQ registrations. The numeric argument equals the operator's total tensor
// count (forward: 4 inputs + 1 output; backward: 5 inputs + 2 outputs).
// NOTE(review): the meaning of the two boolean flags is defined by the DEF
// macro, which is outside this view -- confirm before changing them.
DEF(LSQForward, 5, true, true); | |||||
DEF(LSQBackward, 7, true, false); | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -947,6 +947,119 @@ struct OpCallerUniform<Op, 5, PVis> { | |||||
} | } | ||||
}; | }; | ||||
//! specialization for arity == 6 | |||||
template <class Op, class PVis> | |||||
struct OpCallerUniform<Op, 6, PVis> { | |||||
Op op; | |||||
PVis par[6]; | |||||
static const uint32_t packed_size = PVis::packed_size; | |||||
devfunc void thread_init(uint32_t idx) { | |||||
idx = idx * packed_size; | |||||
par[0].thread_init(idx); | |||||
par[1].thread_init(idx); | |||||
par[2].thread_init(idx); | |||||
par[3].thread_init(idx); | |||||
par[4].thread_init(idx); | |||||
par[5].thread_init(idx); | |||||
} | |||||
devfunc void on(uint32_t idx) { | |||||
idx = idx * packed_size; | |||||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), | |||||
par[4].at(idx), par[5].at(idx)); | |||||
} | |||||
devfunc void on(uint32_t idx, uint32_t remain) { | |||||
idx = idx * packed_size; | |||||
if (remain >= packed_size) { | |||||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), | |||||
par[3].at(idx), par[4].at(idx), par[5].at(idx)); | |||||
} else { | |||||
auto ptr0 = par[0].ptr(); | |||||
auto ptr1 = par[1].ptr(); | |||||
auto ptr2 = par[2].ptr(); | |||||
auto ptr3 = par[3].ptr(); | |||||
auto ptr4 = par[4].ptr(); | |||||
auto ptr5 = par[5].ptr(); | |||||
for (int i = 0; i < remain; i++) { | |||||
op(idx + i, ptr0[par[0].offset(idx + i)], | |||||
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], | |||||
ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], | |||||
ptr5[par[5].offset(idx + i)]); | |||||
} | |||||
} | |||||
} | |||||
devfunc void next() { | |||||
par[0].next(); | |||||
par[1].next(); | |||||
par[2].next(); | |||||
par[3].next(); | |||||
par[4].next(); | |||||
par[5].next(); | |||||
} | |||||
}; | |||||
//! specialization for arity == 7 | |||||
template <class Op, class PVis> | |||||
struct OpCallerUniform<Op, 7, PVis> { | |||||
Op op; | |||||
PVis par[7]; | |||||
static const uint32_t packed_size = PVis::packed_size; | |||||
devfunc void thread_init(uint32_t idx) { | |||||
idx = idx * packed_size; | |||||
par[0].thread_init(idx); | |||||
par[1].thread_init(idx); | |||||
par[2].thread_init(idx); | |||||
par[3].thread_init(idx); | |||||
par[4].thread_init(idx); | |||||
par[5].thread_init(idx); | |||||
par[6].thread_init(idx); | |||||
} | |||||
devfunc void on(uint32_t idx) { | |||||
idx = idx * packed_size; | |||||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), | |||||
par[4].at(idx), par[5].at(idx), par[6].at(idx)); | |||||
} | |||||
devfunc void on(uint32_t idx, uint32_t remain) { | |||||
idx = idx * packed_size; | |||||
if (remain >= packed_size) { | |||||
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), | |||||
par[3].at(idx), par[4].at(idx), par[5].at(idx), par[6].at(idx)); | |||||
} else { | |||||
auto ptr0 = par[0].ptr(); | |||||
auto ptr1 = par[1].ptr(); | |||||
auto ptr2 = par[2].ptr(); | |||||
auto ptr3 = par[3].ptr(); | |||||
auto ptr4 = par[4].ptr(); | |||||
auto ptr5 = par[5].ptr(); | |||||
auto ptr6 = par[6].ptr(); | |||||
for (int i = 0; i < remain; i++) { | |||||
op(idx + i, ptr0[par[0].offset(idx + i)], | |||||
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], | |||||
ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], | |||||
ptr5[par[5].offset(idx + i)], ptr6[par[6].offset(idx + i)]); | |||||
} | |||||
} | |||||
} | |||||
devfunc void next() { | |||||
par[0].next(); | |||||
par[1].next(); | |||||
par[2].next(); | |||||
par[3].next(); | |||||
par[4].next(); | |||||
par[5].next(); | |||||
par[6].next(); | |||||
} | |||||
}; | |||||
/*! | /*! | ||||
* \brief call binary (i.e. arity == 2) operator with different param | * \brief call binary (i.e. arity == 2) operator with different param | ||||
* visitors | * visitors | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/common/handle_impl.h" | #include "src/common/handle_impl.h" | ||||
@@ -15,6 +16,7 @@ | |||||
#include "src/cuda/add_update/opr_impl.h" | #include "src/cuda/add_update/opr_impl.h" | ||||
#include "src/cuda/argmxx/opr_impl.h" | #include "src/cuda/argmxx/opr_impl.h" | ||||
#include "src/cuda/argsort/opr_impl.h" | #include "src/cuda/argsort/opr_impl.h" | ||||
#include "src/cuda/batch_conv_bias/opr_impl.h" | |||||
#include "src/cuda/batch_normalization/opr_impl.h" | #include "src/cuda/batch_normalization/opr_impl.h" | ||||
#include "src/cuda/batched_matrix_mul/opr_impl.h" | #include "src/cuda/batched_matrix_mul/opr_impl.h" | ||||
#include "src/cuda/check_has_inf/opr_impl.h" | #include "src/cuda/check_has_inf/opr_impl.h" | ||||
@@ -35,6 +37,7 @@ | |||||
#include "src/cuda/elemwise/opr_impl.h" | #include "src/cuda/elemwise/opr_impl.h" | ||||
#include "src/cuda/elemwise_multi_type/opr_impl.h" | #include "src/cuda/elemwise_multi_type/opr_impl.h" | ||||
#include "src/cuda/eye/opr_impl.h" | #include "src/cuda/eye/opr_impl.h" | ||||
#include "src/cuda/fake_quant/opr_impl.h" | |||||
#include "src/cuda/flip/opr_impl.h" | #include "src/cuda/flip/opr_impl.h" | ||||
#include "src/cuda/gaussian_blur/opr_impl.h" | #include "src/cuda/gaussian_blur/opr_impl.h" | ||||
#include "src/cuda/group_local/opr_impl.h" | #include "src/cuda/group_local/opr_impl.h" | ||||
@@ -45,6 +48,7 @@ | |||||
#include "src/cuda/local/opr_impl.h" | #include "src/cuda/local/opr_impl.h" | ||||
#include "src/cuda/local_share/opr_impl.h" | #include "src/cuda/local_share/opr_impl.h" | ||||
#include "src/cuda/lrn/opr_impl.h" | #include "src/cuda/lrn/opr_impl.h" | ||||
#include "src/cuda/lsq/opr_impl.h" | |||||
#include "src/cuda/mask_conv/opr_impl.h" | #include "src/cuda/mask_conv/opr_impl.h" | ||||
#include "src/cuda/matrix_inverse/opr_impl.h" | #include "src/cuda/matrix_inverse/opr_impl.h" | ||||
#include "src/cuda/matrix_mul/opr_impl.h" | #include "src/cuda/matrix_mul/opr_impl.h" | ||||
@@ -56,9 +60,11 @@ | |||||
#include "src/cuda/reduce/opr_impl.h" | #include "src/cuda/reduce/opr_impl.h" | ||||
#include "src/cuda/relayout/opr_impl.h" | #include "src/cuda/relayout/opr_impl.h" | ||||
#include "src/cuda/relayout_format/opr_impl.h" | #include "src/cuda/relayout_format/opr_impl.h" | ||||
#include "src/cuda/remap/opr_impl.h" | |||||
#include "src/cuda/repeat/opr_impl.h" | #include "src/cuda/repeat/opr_impl.h" | ||||
#include "src/cuda/resize/opr_impl.h" | #include "src/cuda/resize/opr_impl.h" | ||||
#include "src/cuda/rng/opr_impl.h" | #include "src/cuda/rng/opr_impl.h" | ||||
#include "src/cuda/roi_align/opr_impl.h" | |||||
#include "src/cuda/roi_copy/opr_impl.h" | #include "src/cuda/roi_copy/opr_impl.h" | ||||
#include "src/cuda/roi_pooling/opr_impl.h" | #include "src/cuda/roi_pooling/opr_impl.h" | ||||
#include "src/cuda/rotate/opr_impl.h" | #include "src/cuda/rotate/opr_impl.h" | ||||
@@ -70,16 +76,11 @@ | |||||
#include "src/cuda/tensor_remap/opr_impl.h" | #include "src/cuda/tensor_remap/opr_impl.h" | ||||
#include "src/cuda/tile/opr_impl.h" | #include "src/cuda/tile/opr_impl.h" | ||||
#include "src/cuda/topk/opr_impl.h" | #include "src/cuda/topk/opr_impl.h" | ||||
#include "src/cuda/tqt/opr_impl.h" | |||||
#include "src/cuda/transpose/opr_impl.h" | #include "src/cuda/transpose/opr_impl.h" | ||||
#include "src/cuda/type_cvt/opr_impl.h" | #include "src/cuda/type_cvt/opr_impl.h" | ||||
#include "src/cuda/warp_affine/opr_impl.h" | #include "src/cuda/warp_affine/opr_impl.h" | ||||
#include "src/cuda/warp_perspective/opr_impl.h" | #include "src/cuda/warp_perspective/opr_impl.h" | ||||
#include "src/cuda/local_share/opr_impl.h" | |||||
#include "src/cuda/roi_align/opr_impl.h" | |||||
#include "src/cuda/batch_conv_bias/opr_impl.h" | |||||
#include "src/cuda/remap/opr_impl.h" | |||||
#include "src/cuda/fake_quant/opr_impl.h" | |||||
#include "src/cuda/tqt/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace cuda { | namespace cuda { | ||||
@@ -0,0 +1,30 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lsq/kern.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "./kern.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
// Apply INST_RUN_ELEMWISE to each LSQ kernel functor for one dtype. The
// arities are 3 for the two contiguous functors (scale, zero_point,
// grad_scale), 5 for the non-contiguous forward functor and 7 for the
// non-contiguous backward functor.
#define cb(_dtype)                                                       \
    INST_RUN_ELEMWISE(LSQKernOp<DTypeTrait<_dtype>::ctype>,              \
                      DTypeTrait<_dtype>::ctype, 3);                     \
    INST_RUN_ELEMWISE(LSQBwdKernOp<DTypeTrait<_dtype>::ctype>,           \
                      DTypeTrait<_dtype>::ctype, 3);                     \
    INST_RUN_ELEMWISE(LSQKernOpNonContig<DTypeTrait<_dtype>::ctype>,     \
                      DTypeTrait<_dtype>::ctype, 5);                     \
    INST_RUN_ELEMWISE(LSQBwdKernOpNonContig<DTypeTrait<_dtype>::ctype>,  \
                      DTypeTrait<_dtype>::ctype, 7);
cb(megdnn::dtype::Float32)
// do not let the helper macro leak past its single use
#undef cb
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,126 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lsq/kern.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/cuda/elemwise_helper.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
#if MEGDNN_CC_HOST | |||||
#include "megdnn/oprs.h" | |||||
#endif | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
template <typename ctype> | |||||
struct LSQKernOp { | |||||
ctype* input; | |||||
ctype* output; | |||||
ctype qmin, qmax; | |||||
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, | |||||
ctype grad_scale) { | |||||
ctype x = input[idx] / scale + zero_point; | |||||
x = fmaxf(fminf(x, qmax), qmin); | |||||
x = round(x); | |||||
output[idx] = (x - zero_point) * scale; | |||||
} | |||||
#if MEGDNN_CC_HOST | |||||
LSQKernOp(const TensorND& input, const TensorND& output, | |||||
const LSQ::Param& param) | |||||
: input{input.ptr<ctype>()}, | |||||
output{output.ptr<ctype>()}, | |||||
qmin(param.qmin), | |||||
qmax(param.qmax) {} | |||||
#endif | |||||
}; | |||||
template <typename ctype> | |||||
struct LSQBwdKernOp { | |||||
ctype* diff; | |||||
ctype* input; | |||||
ctype* grad_x; | |||||
ctype* grad_s; | |||||
ctype qmin, qmax; | |||||
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, | |||||
ctype grad_scale) { | |||||
ctype x = input[idx] / scale + zero_point; | |||||
bool ind_small = x < qmin; | |||||
bool ind_big = x > qmax; | |||||
bool ind_middle = ind_small ^ ind_big; | |||||
ind_middle = !ind_middle; | |||||
grad_s[idx] = ind_small * qmin + ind_big * qmax + | |||||
ind_middle * (-x + round(x)); | |||||
grad_s[idx] = grad_s[idx] * grad_scale * diff[idx]; | |||||
grad_x[idx] = ind_middle * diff[idx]; | |||||
} | |||||
#if MEGDNN_CC_HOST | |||||
LSQBwdKernOp(const TensorND& diff, const TensorND& input, | |||||
const TensorND& grad_x, const TensorND& grad_s, | |||||
const LSQ::Param& param) | |||||
: diff{diff.ptr<ctype>()}, | |||||
input{input.ptr<ctype>()}, | |||||
grad_x{grad_x.ptr<ctype>()}, | |||||
grad_s{grad_s.ptr<ctype>()}, | |||||
qmin(param.qmin), | |||||
qmax(param.qmax) {} | |||||
#endif | |||||
}; | |||||
template <typename ctype> | |||||
struct LSQKernOpNonContig { | |||||
ctype qmin; | |||||
ctype qmax; | |||||
__device__ void operator()(uint32_t, ctype& output, ctype& input, | |||||
ctype& scale, ctype& zero_point, | |||||
ctype grad_scale) { | |||||
ctype x = input / scale + zero_point; | |||||
x = fmaxf(fminf(x, qmax), qmin); | |||||
x = round(x); | |||||
output = (x - zero_point) * scale; | |||||
} | |||||
#if MEGDNN_CC_HOST | |||||
LSQKernOpNonContig(const LSQ::Param& param) | |||||
: qmin(param.qmin), qmax(param.qmax) {} | |||||
#endif | |||||
}; | |||||
template <typename ctype> | |||||
struct LSQBwdKernOpNonContig { | |||||
ctype qmin; | |||||
ctype qmax; | |||||
__device__ void operator()(uint32_t, ctype& grad_x, ctype& grad_s, | |||||
ctype& diff, ctype& input, ctype& scale, | |||||
ctype& zero_point, ctype grad_scale) { | |||||
ctype x = input / scale + zero_point; | |||||
bool ind_small = x < qmin; | |||||
bool ind_big = x > qmax; | |||||
bool ind_middle = ind_small ^ ind_big; | |||||
ind_middle = !ind_middle; | |||||
grad_s = ind_small * qmin + ind_big * qmax + | |||||
ind_middle * (-x + round(x)); | |||||
grad_s = grad_s * grad_scale * diff; | |||||
grad_x = ind_middle * diff; | |||||
} | |||||
#if MEGDNN_CC_HOST | |||||
LSQBwdKernOpNonContig(const LSQ::Param& param) | |||||
: qmin(param.qmin), qmax(param.qmax) {} | |||||
#endif | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,151 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lsq/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "./opr_impl.h" | |||||
#include "./kern.cuh" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
/*!
 * CUDA forward entry point. When input and output are both contiguous the
 * kernel indexes them directly and only scale / zero_point / grad_scale
 * (broadcast to the input layout) go through the elemwise visitor; otherwise
 * the work is handed to exec_noncontig().
 */
void LSQForwardImpl::exec(
        _megdnn_tensor_in input, _megdnn_tensor_in scale,
        _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
        _megdnn_tensor_out output, _megdnn_workspace workspace) {
    check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout,
               output.layout, workspace.size);

    const bool all_contiguous =
            input.layout.is_contiguous() && output.layout.is_contiguous();
    if (!all_contiguous) {
        exec_noncontig(input, scale, zero_point, grad_scale, output);
        return;
    }

    ElemwiseOpParamN<3> ele_param;
    // store a tensor in the given slot with its layout broadcast to the input
    auto put_broadcast = [&](int slot, const TensorND& tensor) {
        ele_param[slot] = tensor;
        ele_param[slot].layout =
                ele_param[slot].layout.broadcast(input.layout);
    };
    put_broadcast(0, scale);
    put_broadcast(1, zero_point);
    put_broadcast(2, grad_scale);
    ele_param.init_from_given_tensor();

    auto m_param = param();
    auto stream = cuda_stream(handle());

    // Float32 is the only dtype handled
    if (input.layout.dtype == megdnn::dtype::Float32()) {
        using T = DTypeTrait<megdnn::dtype::Float32>::ctype;
        run_elemwise<LSQKernOp<T>, T, 3>(
                ele_param, stream, {input, output, m_param});
        return;
    }
}
/*!
 * Forward path where all five tensors go through the elemwise visitor:
 * output and input as given, scale / zero_point / grad_scale broadcast to
 * the input layout.
 */
void LSQForwardImpl::exec_noncontig(
        _megdnn_tensor_in input, _megdnn_tensor_in scale,
        _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
        _megdnn_tensor_out output) {
    ElemwiseOpParamN<5> ele_param;
    ele_param[0] = output;
    ele_param[1] = input;
    // store a tensor in the given slot with its layout broadcast to the input
    auto put_broadcast = [&](int slot, const TensorND& tensor) {
        ele_param[slot] = tensor;
        ele_param[slot].layout =
                ele_param[slot].layout.broadcast(input.layout);
    };
    put_broadcast(2, scale);
    put_broadcast(3, zero_point);
    put_broadcast(4, grad_scale);
    ele_param.init_from_given_tensor();

    auto m_param = param();
    auto stream = cuda_stream(handle());

    // Float32 is the only dtype handled
    if (input.layout.dtype == megdnn::dtype::Float32()) {
        using T = DTypeTrait<megdnn::dtype::Float32>::ctype;
        run_elemwise<LSQKernOpNonContig<T>, T, 5>(
                ele_param, stream, {m_param});
        return;
    }
}
/*!
 * CUDA backward entry point. When input, diff, grad_x and grad_s are all
 * contiguous the kernel indexes them directly and only scale / zero_point /
 * grad_scale (broadcast to the input layout) go through the elemwise
 * visitor; otherwise the work is handed to exec_noncontig().
 */
void LSQBackwardImpl::exec(
        _megdnn_tensor_in diff, _megdnn_tensor_in input,
        _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
        _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
        _megdnn_tensor_out grad_s, _megdnn_workspace workspace) {
    check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
               grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size);

    const bool all_contiguous =
            input.layout.is_contiguous() && diff.layout.is_contiguous() &&
            grad_x.layout.is_contiguous() && grad_s.layout.is_contiguous();
    if (!all_contiguous) {
        exec_noncontig(diff, input, scale, zero_point, grad_scale, grad_x,
                       grad_s);
        return;
    }

    ElemwiseOpParamN<3> ele_param;
    // store a tensor in the given slot with its layout broadcast to the input
    auto put_broadcast = [&](int slot, const TensorND& tensor) {
        ele_param[slot] = tensor;
        ele_param[slot].layout =
                ele_param[slot].layout.broadcast(input.layout);
    };
    put_broadcast(0, scale);
    put_broadcast(1, zero_point);
    put_broadcast(2, grad_scale);
    ele_param.init_from_given_tensor();

    auto m_param = param();
    auto stream = cuda_stream(handle());

    // Float32 is the only dtype handled; note the dispatch keys on grad_x
    if (grad_x.layout.dtype == megdnn::dtype::Float32()) {
        using T = DTypeTrait<megdnn::dtype::Float32>::ctype;
        run_elemwise<LSQBwdKernOp<T>, T, 3>(
                ele_param, stream, {diff, input, grad_x, grad_s, m_param});
        return;
    }
}
/*!
 * Backward path where all seven tensors go through the elemwise visitor:
 * grad_x, grad_s, diff and input as given, scale / zero_point / grad_scale
 * broadcast to the input layout.
 */
void LSQBackwardImpl::exec_noncontig(
        _megdnn_tensor_in diff, _megdnn_tensor_in input,
        _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
        _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
        _megdnn_tensor_out grad_s) {
    ElemwiseOpParamN<7> ele_param;
    ele_param[0] = grad_x;
    ele_param[1] = grad_s;
    ele_param[2] = diff;
    ele_param[3] = input;
    // store a tensor in the given slot with its layout broadcast to the input
    auto put_broadcast = [&](int slot, const TensorND& tensor) {
        ele_param[slot] = tensor;
        ele_param[slot].layout =
                ele_param[slot].layout.broadcast(input.layout);
    };
    put_broadcast(4, scale);
    put_broadcast(5, zero_point);
    put_broadcast(6, grad_scale);
    ele_param.init_from_given_tensor();

    auto m_param = param();
    auto stream = cuda_stream(handle());

    // Float32 is the only dtype handled
    if (input.layout.dtype == megdnn::dtype::Float32()) {
        using T = DTypeTrait<megdnn::dtype::Float32>::ctype;
        run_elemwise<LSQBwdKernOpNonContig<T>, T, 7>(
                ele_param, stream, {m_param});
        return;
    }
}
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,65 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lsq/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/cuda/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class LSQForwardImpl final : public LSQForward { | |||||
public: | |||||
using LSQForward::LSQForward; | |||||
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||||
_megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, | |||||
_megdnn_tensor_out output, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, /* input */ | |||||
const TensorLayout&, /* scale */ | |||||
const TensorLayout&, /* zero_point */ | |||||
const TensorLayout&, /* grad_scale */ | |||||
const TensorLayout& /* output */) override { | |||||
return 0; | |||||
} | |||||
private: | |||||
void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||||
_megdnn_tensor_in zero_point, | |||||
_megdnn_tensor_in grad_scale, | |||||
_megdnn_tensor_out output); | |||||
}; | |||||
class LSQBackwardImpl final : public LSQBackward { | |||||
public: | |||||
using LSQBackward::LSQBackward; | |||||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||||
_megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout& /* diff */, | |||||
const TensorLayout& /* input */, | |||||
const TensorLayout& /* scale */, | |||||
const TensorLayout& /* zero_point */, | |||||
const TensorLayout& /* grad_scale */, | |||||
const TensorLayout& /* grad_x */, | |||||
const TensorLayout& /* grad_s */) override { | |||||
return 0; | |||||
} | |||||
private: | |||||
void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||||
_megdnn_tensor_out grad_s); | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -50,6 +50,7 @@ | |||||
#include "src/naive/local/opr_impl.h" | #include "src/naive/local/opr_impl.h" | ||||
#include "src/naive/local_share/opr_impl.h" | #include "src/naive/local_share/opr_impl.h" | ||||
#include "src/naive/lrn/opr_impl.h" | #include "src/naive/lrn/opr_impl.h" | ||||
#include "src/naive/lsq/opr_impl.h" | |||||
#include "src/naive/mask_conv/opr_impl.h" | #include "src/naive/mask_conv/opr_impl.h" | ||||
#include "src/naive/matrix_inverse/opr_impl.h" | #include "src/naive/matrix_inverse/opr_impl.h" | ||||
#include "src/naive/matrix_mul/opr_impl.h" | #include "src/naive/matrix_mul/opr_impl.h" | ||||
@@ -0,0 +1,141 @@ | |||||
/** | |||||
* \file dnn/src/naive/lsq/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/naive/lsq/opr_impl.h" | |||||
#include <cmath> | |||||
#include "megdnn/tensor_iter.h" | |||||
#include "src/common/elemwise_helper.cuh" | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
namespace { | |||||
using namespace megdnn; | |||||
template <typename T> | |||||
void forward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) { | |||||
auto inp = tensor_iter_valonly<T>(src[0]).begin(); | |||||
auto out = tensor_iter_valonly<T>(src[1]).begin(); | |||||
auto scale = tensor_iter_valonly<T>(src[2]).begin(); | |||||
auto zero_point = tensor_iter_valonly<T>(src[3]).begin(); | |||||
auto grad_scale = tensor_iter_valonly<T>(src[4]).begin(); | |||||
size_t total = src[0].layout.total_nr_elems(); | |||||
for (size_t i = 0; i < total; ++i) { | |||||
T x = (*inp) / (*scale) + (*zero_point); | |||||
x = x <= qmin ? qmin : x; | |||||
x = x >= qmax ? qmax : x; | |||||
x = round(x); | |||||
*out = (x - (*zero_point)) * (*scale); | |||||
++inp; | |||||
++out; | |||||
++scale; | |||||
++zero_point; | |||||
++grad_scale; | |||||
} | |||||
} | |||||
template <typename T> | |||||
void backward_impl(const ElemwiseOpParamN<7> src, float qmin, float qmax) { | |||||
auto diff = tensor_iter_valonly<T>(src[0]).begin(); | |||||
auto input = tensor_iter_valonly<T>(src[1]).begin(); | |||||
auto scale = tensor_iter_valonly<T>(src[2]).begin(); | |||||
auto zero_point = tensor_iter_valonly<T>(src[3]).begin(); | |||||
auto grad_scale = tensor_iter_valonly<T>(src[4]).begin(); | |||||
auto grad_x = tensor_iter_valonly<T>(src[5]).begin(); | |||||
auto grad_s = tensor_iter_valonly<T>(src[6]).begin(); | |||||
size_t total = src[0].layout.total_nr_elems(); | |||||
for (size_t i = 0; i < total; ++i) { | |||||
T x = (*input) / (*scale) + (*zero_point); | |||||
bool ind_small = x < qmin; | |||||
bool ind_big = x > qmax; | |||||
bool ind_middle = ind_small ^ ind_big; | |||||
ind_middle = !ind_middle; | |||||
*grad_s = ind_small * qmin + ind_big * qmax + | |||||
ind_middle * (-x + round(x)); | |||||
*grad_s = (*grad_s) * (*grad_scale) * (*diff); | |||||
*grad_x = ind_middle * (*diff); | |||||
++diff; | |||||
++input; | |||||
++scale; | |||||
++zero_point; | |||||
++grad_scale; | |||||
++grad_x; | |||||
++grad_s; | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace naive { | |||||
/*!
 * Validates the layouts, broadcasts scale / zero_point / grad_scale to the
 * layout of \p input and dispatches the forward kernel. Only Float32 input is
 * dispatched; any other input dtype falls through without running a kernel.
 */
void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
                          _megdnn_tensor_in zero_point,
                          _megdnn_tensor_in grad_scale,
                          _megdnn_tensor_out output,
                          _megdnn_workspace workspace) {
    check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout,
               output.layout, workspace.size);

    // operand order expected by forward_impl
    ElemwiseOpParamN<5> src;
    src[0] = input;
    src[1] = output;
    src[2] = scale;
    src[3] = zero_point;
    src[4] = grad_scale;
    // the quantization parameters are broadcast to the input layout
    for (int i = 2; i < 5; ++i) {
        src[i].layout = src[i].layout.broadcast(input.layout);
    }

    if (input.layout.dtype == dtype::Float32()) {
        using T = DTypeTrait<dtype::Float32>::ctype;
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                forward_impl<T>(src, param().qmin, param().qmax));
        return;
    }
}
/*!
 * Validates the layouts, broadcasts scale / zero_point / grad_scale to the
 * layout of \p input and dispatches the backward kernel. The kernel is only
 * dispatched when diff, grad_x and input are all Float32; otherwise the call
 * falls through without running a kernel.
 */
void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
                           _megdnn_tensor_in scale,
                           _megdnn_tensor_in zero_point,
                           _megdnn_tensor_in grad_scale,
                           _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
                           _megdnn_workspace workspace) {
    check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
               grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size);

    // operand order expected by backward_impl
    ElemwiseOpParamN<7> src;
    src[0] = diff;
    src[1] = input;
    src[2] = scale;
    src[3] = zero_point;
    src[4] = grad_scale;
    src[5] = grad_x;
    src[6] = grad_s;
    // the quantization parameters are broadcast to the input layout
    for (int i = 2; i < 5; ++i) {
        src[i].layout = src[i].layout.broadcast(input.layout);
    }

    const bool all_f32 = diff.layout.dtype == dtype::Float32() &&
                         grad_x.layout.dtype == dtype::Float32() &&
                         input.layout.dtype == dtype::Float32();
    if (all_f32) {
        using T = DTypeTrait<dtype::Float32>::ctype;
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                backward_impl<T>(src, param().qmin, param().qmax));
        return;
    }
}
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,53 @@ | |||||
/** | |||||
* \file dnn/src/naive/lsq/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class LSQForwardImpl final : public LSQForward { | |||||
public: | |||||
using LSQForward::LSQForward; | |||||
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||||
_megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, | |||||
_megdnn_tensor_out output, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout& /* input */, | |||||
const TensorLayout& /* scale */, | |||||
const TensorLayout& /* zero_point */, | |||||
const TensorLayout& /* grad_scale */, | |||||
const TensorLayout& /* output */) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
class LSQBackwardImpl final : public LSQBackward { | |||||
public: | |||||
using LSQBackward::LSQBackward; | |||||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||||
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||||
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||||
_megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout& /* diff */, | |||||
const TensorLayout& /* input */, | |||||
const TensorLayout& /* scale */, | |||||
const TensorLayout& /* zero_point */, | |||||
const TensorLayout& /* grad_scale */, | |||||
const TensorLayout& /* grad_x */, | |||||
const TensorLayout& /* grad_s */) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,53 @@ | |||||
/** | |||||
* \file dnn/test/common/lsq.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
namespace lsq { | |||||
struct TestArg { | |||||
param::LSQ param; | |||||
TensorShape ishape; | |||||
TensorShape scale_shape; | |||||
TensorShape zeropoint_shape; | |||||
TensorShape gradscale_shape; | |||||
TestArg(param::LSQ param, TensorShape ishape, TensorShape scale_shape, | |||||
TensorShape zeropoint_shape, TensorShape gradscale_shape) | |||||
: param(param), | |||||
ishape(ishape), | |||||
scale_shape(scale_shape), | |||||
zeropoint_shape(zeropoint_shape), | |||||
gradscale_shape(gradscale_shape) {} | |||||
}; | |||||
inline std::vector<TestArg> get_args() { | |||||
std::vector<TestArg> args; | |||||
param::LSQ cur_param; | |||||
cur_param.qmin = -127; | |||||
cur_param.qmax = 127; | |||||
for (size_t i = 10; i < 30; i += 2) { | |||||
args.emplace_back(cur_param, TensorShape{10, 64, i, i}, TensorShape{1}, | |||||
TensorShape{1}, TensorShape{1}); | |||||
} | |||||
return args; | |||||
} | |||||
} // namespace lsq | |||||
} // namespace test | |||||
} // namespace megdnn |
@@ -0,0 +1,110 @@ | |||||
/** | |||||
* \file dnn/test/cuda/lsq.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "test/common/lsq.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/cuda/fixture.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
using namespace lsq; | |||||
// Checks LSQForward on the CUDA handle, first with contiguous Float32
// operands and then with an input/output layout whose batch stride is
// doubled.
TEST_F(CUDA, LSQ) {
    std::vector<TestArg> args = get_args();
    auto f32 = dtype::Float32();

    // contiguous operands, described by shape only
    for (auto&& arg : args) {
        Checker<LSQForward> checker(handle_cuda());
        checker.set_param(arg.param);
        for (size_t idx = 0; idx < 5; ++idx) {
            checker.set_dtype(idx, f32);
        }
        checker.execs({arg.ishape, arg.scale_shape, arg.zeropoint_shape,
                       arg.gradscale_shape, arg.ishape});
    }

    // test noncontiguous layout
    for (auto&& arg : args) {
        auto ishape = arg.ishape;
        Checker<LSQForward> checker(handle_cuda());
        // stride of the outermost axis is twice the contiguous stride
        TensorLayout ilayout(
                ishape,
                {static_cast<long int>(ishape[1] * ishape[2] * ishape[3] * 2),
                 static_cast<long int>(ishape[2] * ishape[3]),
                 static_cast<long int>(ishape[3]), 1},
                dtype::Float32());
        TensorLayout scale_layout{arg.scale_shape, dtype::Float32()};
        TensorLayout zeropoint_layout{arg.zeropoint_shape, dtype::Float32()};
        TensorLayout gradscale_layout{arg.gradscale_shape, dtype::Float32()};
        checker.set_param(arg.param)
                .execl({ilayout, scale_layout, zeropoint_layout,
                        gradscale_layout, ilayout});
    }
}
// Checks LSQBackward on the CUDA handle, first with contiguous Float32
// operands and then with diff/input/grad layouts whose batch stride is
// doubled.
TEST_F(CUDA, LSQ_BACKWARD) {
    std::vector<TestArg> args = get_args();
    auto f32 = dtype::Float32();

    // contiguous operands, described by shape only
    for (auto&& arg : args) {
        Checker<LSQBackward> checker(handle_cuda());
        checker.set_param(arg.param);
        for (size_t idx = 0; idx < 7; ++idx) {
            checker.set_dtype(idx, f32);
        }
        checker.execs({arg.ishape, arg.ishape, arg.scale_shape,
                       arg.zeropoint_shape, arg.gradscale_shape, arg.ishape,
                       arg.ishape});
    }

    // test noncontiguous layout
    for (auto&& arg : args) {
        auto ishape = arg.ishape;
        Checker<LSQBackward> checker(handle_cuda());
        // stride of the outermost axis is twice the contiguous stride
        TensorLayout ilayout(
                ishape,
                {static_cast<long int>(ishape[1] * ishape[2] * ishape[3] * 2),
                 static_cast<long int>(ishape[2] * ishape[3]),
                 static_cast<long int>(ishape[3]), 1},
                dtype::Float32());
        TensorLayout scale_layout{arg.scale_shape, dtype::Float32()};
        TensorLayout zeropoint_layout{arg.zeropoint_shape, dtype::Float32()};
        TensorLayout gradscale_layout{arg.gradscale_shape, dtype::Float32()};
        checker.set_param(arg.param)
                .execl({ilayout, ilayout, scale_layout, zeropoint_layout,
                        gradscale_layout, ilayout, ilayout});
    }
}
} // namespace test | |||||
} // namespace megdnn |
@@ -0,0 +1,45 @@ | |||||
/** | |||||
 * \file dnn/test/naive/lsq.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "test/naive/fixture.h" | |||||
#include "megdnn/oprs/nn.h" | |||||
#include "test/common/checker.h" | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
// Checks the naive LSQ forward kernel against hand-computed values:
// with scale = 2 and zero_point = 1, each input v maps to
// (round(v / 2 + 1) - 1) * 2.
TEST_F(NAIVE, LSQ_FORWARD) {
    param::LSQ lsq_param;
    lsq_param.qmin = -127;
    lsq_param.qmax = 127;

    TensorND input =
            TensorValue({2, 2, 2, 2}, dtype::Float32(),
                        {0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 6, 7, 4, 5, 7, 8});
    TensorND scale = TensorValue({1}, dtype::Float32(), {2});
    TensorND zero_point = TensorValue({1}, dtype::Float32(), {1});
    TensorND grad_scale = TensorValue({1}, dtype::Float32(), {0.5});

    TensorND expected =
            TensorValue({2, 2, 2, 2}, dtype::Float32(),
                        {0, 2, 4, 4, 2, 2, 4, 6, 4, 4, 6, 8, 4, 6, 8, 8});

    Checker<LSQ> checker(handle(), /* check_dispatch */ false);
    checker.set_param(lsq_param);
    checker.exect(Testcase{input, scale, zero_point, grad_scale, {}},
                  Testcase{{}, {}, {}, {}, expected});
}
@@ -6,7 +6,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .fake_quant import TQT, FakeQuantize | |||||
from .fake_quant import LSQ, TQT, FakeQuantize | |||||
from .observer import ( | from .observer import ( | ||||
ExponentialMovingAverageObserver, | ExponentialMovingAverageObserver, | ||||
HistogramObserver, | HistogramObserver, | ||||
@@ -12,13 +12,15 @@ from .. import functional as F | |||||
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..module import Module | from ..module import Module | ||||
from ..tensor import Parameter | |||||
from ..tensor import Parameter, Tensor | |||||
from .utils import ( | from .utils import ( | ||||
LSQParams, | |||||
QParams, | QParams, | ||||
QParamsModuleMixin, | QParamsModuleMixin, | ||||
QuantMode, | QuantMode, | ||||
create_qparams, | create_qparams, | ||||
fake_quant_tensor, | fake_quant_tensor, | ||||
lsq_forward, | |||||
tqt_forward, | tqt_forward, | ||||
) | ) | ||||
@@ -117,3 +119,58 @@ class FakeQuantize(_FakeQuantize): | |||||
qparams.dtype_meta, self.dtype | qparams.dtype_meta, self.dtype | ||||
) | ) | ||||
return fake_quant_tensor(inp, qparams) | return fake_quant_tensor(inp, qparams) | ||||
class LSQ(_FakeQuantize, QParamsModuleMixin):
    r"""
    LSQ: https://arxiv.org/pdf/1902.08153.pdf Estimating and scaling the
    task loss gradient at each weight and activation layer's quantizer step size

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
    :param eps: a small value to avoid division by zero. Default: 1e-5
    """

    # NOTE: this must be ``__init__``; under any other name the constructor is
    # never run on instantiation and ``eps`` / ``step_size`` stay undefined.
    def __init__(
        self,
        dtype: Union[str, QuantDtypeMeta],
        enable: bool = True,
        eps: float = 1e-5,
        **kwargs
    ):
        super().__init__(dtype=dtype, enable=enable, **kwargs)
        self.eps = Tensor(eps, dtype="float32")
        # learnable step size; replaced by ``set_qparams`` once a scale is known
        self.step_size = Parameter(1.0, dtype="float32")

    def set_qparams(self, qparams: LSQParams):
        """
        Initialize mode, zero point, step size and grad scale from ``qparams``.

        :param qparams: source parameters; ``qparams.scale`` must not be None.
        :raises AssertionError: if ``qparams.scale`` is None.
        """
        self.mode = qparams.mode
        if qparams.mode == QuantMode.ASYMMERTIC:
            self.zero_point = qparams.zero_point
        else:
            self.zero_point = Tensor([0.0], dtype="float32")
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        # ``fake_quant_forward`` uses ``abs(step_size) + eps`` as the scale, so
        # eps is subtracted here (bottoming out at 0 for scales below eps).
        init_step_size = qparams.scale
        if init_step_size < self.eps:
            init_step_size = 0
        else:
            init_step_size = init_step_size - self.eps
        self.step_size = Parameter(init_step_size, dtype="float32")
        self.grad_scale = qparams.grad_scale

    def fake_quant_forward(self, inp, qparams: LSQParams = None):
        """
        Fake-quantize ``inp`` with the module's own step size, zero point and
        grad scale. The ``qparams`` argument is not used.
        """
        step_size = F.abs(self.step_size) + self.eps
        return lsq_forward(
            self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale
        )

    def get_qparams(self):
        """
        Return the current parameters as :class:`LSQParams`; the scale is
        ``abs(step_size) + eps`` computed on a detached step size.
        """
        return LSQParams(
            mode=self.mode,
            dtype_meta=self.dtype,
            scale=F.abs(self.step_size.detach()) + self.eps,
            zero_point=self.zero_point,
            grad_scale=self.grad_scale,
        )
@@ -43,6 +43,16 @@ def tqt_forward(qmin, qmax, inp, scale): | |||||
return output | return output | ||||
def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
    """
    Apply the builtin ``LSQ`` op to ``inp`` and return its single output.

    :param qmin: lower quantization bound passed to the op.
    :param qmax: upper quantization bound passed to the op.
    :param inp: tensor to fake-quantize.
    :param step_size: quantization step size.
    :param zero_point: zero point; a float32 ``[0.0]`` tensor when omitted.
    :param scale_grad: gradient scale; a float32 ``[1.0]`` tensor when omitted.
    """
    zero_point = (
        Tensor([0.0], dtype=np.float32) if zero_point is None else zero_point
    )
    scale_grad = (
        Tensor([1.0], dtype=np.float32) if scale_grad is None else scale_grad
    )
    lsq_op = builtin.LSQ(qmin=qmin, qmax=qmax)
    results = apply(lsq_op, inp, step_size, zero_point, scale_grad)
    # the op yields exactly one output; unpacking enforces that
    (output,) = results
    return output
def register_method_to_class(cls): | def register_method_to_class(cls): | ||||
def decorator(func): | def decorator(func): | ||||
@wraps(func) | @wraps(func) | ||||
@@ -105,6 +115,47 @@ class QParams: | |||||
return "QParams({})".format(content) | return "QParams({})".format(content) | ||||
class LSQParams:
    """
    To standardize LSQ's qparams format. If custom
    qparams is needed, inherit this class and add custom ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
        grad_scale: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point
        self.grad_scale = grad_scale

    def update(self, lsqparams: "LSQParams"):
        """Copy every slot of ``lsqparams`` into this object."""
        for slot in self.__slots__:
            value = getattr(lsqparams, slot)
            setattr(self, slot, value)

    def __eq__(self, other):
        """Equal when both define the same number of slots with equal values."""
        if len(other.__slots__) != len(self.__slots__):
            return False
        return all(
            hasattr(other, slot)
            and not (getattr(self, slot) != getattr(other, slot))
            for slot in self.__slots__
        )

    def __repr__(self):
        pairs = []
        for slot in self.__slots__:
            pairs.append("{}={}".format(slot, getattr(self, slot)))
        return "LSQParams({})".format(", ".join(pairs))
class QParamsModuleMixin(abc.ABC): | class QParamsModuleMixin(abc.ABC): | ||||
def get_quantized_dtype(self): | def get_quantized_dtype(self): | ||||
qparams = self.get_qparams() | qparams = self.get_qparams() | ||||
@@ -10,6 +10,7 @@ import numpy as np | |||||
import pytest | import pytest | ||||
import megengine as mge | import megengine as mge | ||||
import megengine.functional as F | |||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.core.autodiff.grad import Function, Grad | from megengine.core.autodiff.grad import Function, Grad | ||||
from megengine.core.tensor.dtype import QuantDtypeMeta | from megengine.core.tensor.dtype import QuantDtypeMeta | ||||
@@ -19,6 +20,7 @@ from megengine.quantization.utils import ( | |||||
QuantMode, | QuantMode, | ||||
create_qparams, | create_qparams, | ||||
fake_quant_tensor, | fake_quant_tensor, | ||||
lsq_forward, | |||||
tqt_forward, | tqt_forward, | ||||
) | ) | ||||
@@ -150,3 +152,78 @@ def test_fakequant(): | |||||
zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | ||||
scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | ||||
run(zero_point, scale) | run(zero_point, scale) | ||||
class LSQ_numpy:
    """NumPy reference implementation of LSQ forward and backward."""

    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound

    def forward(self, inp, scale, zero_point, grad_scale):
        """Fake-quantize ``inp`` and stash the intermediates for ``backward``."""
        inp_scaled = inp / scale + zero_point
        capped = np.minimum(inp_scaled, self.upperbound)
        inp_clipped = np.maximum(capped, self.lowerbound)
        # round half up
        inp_rounded = np.floor(inp_clipped + 0.5)
        self.saved_tensors = (inp_scaled, inp_rounded, scale, grad_scale)
        return (inp_rounded - zero_point) * scale

    def backward(self, grad_inp_flq):
        """Return ``(grad_inp, grad_s)`` for the last ``forward`` call."""
        inp_scaled, inp_rounded, scale, grad_scale = self.saved_tensors
        lo, hi = self.lowerbound, self.upperbound

        ind_small = inp_scaled < lo
        ind_big = inp_scaled > hi
        # integer 0/1 mask of the elements that are neither below nor above
        ind_middle = np.abs(np.logical_xor(ind_small, ind_big) - 1)

        per_elem = (
            ind_small * lo + ind_big * hi + ind_middle * (-inp_scaled + inp_rounded)
        )
        grad_s = (per_elem * grad_scale * grad_inp_flq).sum()
        grad_inp = grad_inp_flq * ind_middle
        return grad_inp, grad_s
def test_lsq():
    """Compare ``lsq_forward`` and its gradients with the NumPy reference."""

    def preprocess(scale, eps):
        # mirrors the step-size initialisation: subtract eps (floor at 0),
        # then take abs and add eps back
        scale = np.array([0]) if scale < eps else scale - eps
        return np.abs(scale) + eps

    collected = []

    def record(grad):
        collected.append(grad)

    bound_lo, bound_hi = -127, 127
    shape = (1, 2, 3, 4)

    x = np.random.randint(-128, 128, size=shape).astype("float32")
    s = np.random.rand(1)
    eps = np.array([1e-5], dtype="float32")
    s = preprocess(s, eps)
    zero_point = np.array([1.0], dtype="float32")
    grad_s = np.array([2.0], dtype="float32")
    g_y = np.ones(shape=shape, dtype="float32")

    # reference results
    reference = LSQ_numpy(bound_lo, bound_hi)
    y_np = reference.forward(x, s, zero_point, grad_s)
    g_x_np, g_s_np = reference.backward(g_y)

    # results of the operator under test
    x = mge.tensor(x, dtype="float32")
    s = mge.tensor(s, dtype="float32")
    zero_point = mge.tensor(zero_point, dtype="float32")
    grad_s = mge.tensor(grad_s, dtype="float32")
    g_y = mge.tensor(g_y, dtype="float32")
    grad = Grad().wrt(x, s, callback=record)
    y = lsq_forward(bound_lo, bound_hi, x, s, zero_point, grad_s)
    grad(y, g_y)
    g_x, g_s = collected

    np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7)
    np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-7, atol=1e-7)
    np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-7, atol=5e-7)
@@ -6,23 +6,26 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
// FIXME: split this file into separate files for each specialized op | // FIXME: split this file into separate files for each specialized op | ||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/blas.h" | |||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | #include "megbrain/opr/dnn/adaptive_pooling.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "megbrain/opr/dnn/correlation.h" | |||||
#include "megbrain/opr/dnn/fake_quant.h" | #include "megbrain/opr/dnn/fake_quant.h" | ||||
#include "megbrain/opr/dnn/tqt.h" | |||||
#include "megbrain/opr/dnn/pooling.h" | |||||
#include "megbrain/opr/dnn/images2neibs.h" | |||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
#include "megbrain/opr/dnn/lsq.h" | |||||
#include "megbrain/opr/dnn/pooling.h" | |||||
#include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
#include "megbrain/opr/dnn/correlation.h" | |||||
#include "megbrain/opr/dnn/roi_pooling.h" | #include "megbrain/opr/dnn/roi_pooling.h" | ||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/blas.h" | |||||
#include "megbrain/opr/dnn/tqt.h" | |||||
#include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
#include "megbrain/opr/indexing.h" | #include "megbrain/opr/indexing.h" | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
@@ -32,40 +35,38 @@ | |||||
#include "megbrain/opr/tensor_gen.h" | #include "megbrain/opr/tensor_gen.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "megbrain/opr/dnn/images2neibs.h" | |||||
#include "../op_trait.h" | #include "../op_trait.h" | ||||
namespace mgb::imperative { | namespace mgb::imperative { | ||||
namespace { namespace dimshuffle { | |||||
namespace { | |||||
namespace dimshuffle { | |||||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | ||||
auto* node = &node_->cast_final_safe<opr::Dimshuffle>(); | auto* node = &node_->cast_final_safe<opr::Dimshuffle>(); | ||||
std::vector<int> pattern(node->param().pattern_len); | std::vector<int> pattern(node->param().pattern_len); | ||||
for (size_t i = 0; i < node->param().pattern_len; ++ i) { | |||||
for (size_t i = 0; i < node->param().pattern_len; ++i) { | |||||
pattern[i] = node->param().pattern[i]; | pattern[i] = node->param().pattern[i]; | ||||
} | } | ||||
return Dimshuffle::make(pattern); | return Dimshuffle::make(pattern); | ||||
} | } | ||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& ds = static_cast<const Dimshuffle&>(def); | auto&& ds = static_cast<const Dimshuffle&>(def); | ||||
OperatorNodeConfig config{ds.make_name()}; | OperatorNodeConfig config{ds.make_name()}; | ||||
return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config); | return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config); | ||||
} | } | ||||
OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | ||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // dimshuffle | |||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace dimshuffle | |||||
} // namespace | |||||
namespace { namespace add_axis { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace add_axis { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& add_axis = static_cast<const AddAxis&>(def); | auto&& add_axis = static_cast<const AddAxis&>(def); | ||||
using Desc = opr::AxisAddRemove::AxisDesc; | using Desc = opr::AxisAddRemove::AxisDesc; | ||||
std::vector<Desc> param; | std::vector<Desc> param; | ||||
@@ -76,15 +77,13 @@ auto apply_on_var_node( | |||||
return opr::AxisAddRemove::make(inputs[0], param, config); | return opr::AxisAddRemove::make(inputs[0], param, config); | ||||
} | } | ||||
OP_TRAIT_REG(AddAxis, AddAxis) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // add_axis | |||||
OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace add_axis | |||||
} // namespace | |||||
namespace { namespace remove_axis { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace remove_axis { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& remove_axis = static_cast<const RemoveAxis&>(def); | auto&& remove_axis = static_cast<const RemoveAxis&>(def); | ||||
using Desc = opr::AxisAddRemove::AxisDesc; | using Desc = opr::AxisAddRemove::AxisDesc; | ||||
std::vector<Desc> param; | std::vector<Desc> param; | ||||
@@ -96,36 +95,35 @@ auto apply_on_var_node( | |||||
} | } | ||||
OP_TRAIT_REG(RemoveAxis, RemoveAxis) | OP_TRAIT_REG(RemoveAxis, RemoveAxis) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // remove_axis | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace remove_axis | |||||
} // namespace | |||||
namespace { namespace top_k { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace top_k { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& topk = static_cast<const TopK&>(def); | auto&& topk = static_cast<const TopK&>(def); | ||||
OperatorNodeConfig config{topk.make_name()}; | OperatorNodeConfig config{topk.make_name()}; | ||||
return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0] | return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0] | ||||
.node()->owner_opr(); | |||||
.node() | |||||
->owner_opr(); | |||||
} | } | ||||
OP_TRAIT_REG(TopK, TopK) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // top_k | |||||
OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace top_k | |||||
} // namespace | |||||
namespace { namespace reduce { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace reduce { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& reduce = static_cast<const Reduce&>(def); | auto&& reduce = static_cast<const Reduce&>(def); | ||||
OperatorNodeConfig config{reduce.make_name()}; | OperatorNodeConfig config{reduce.make_name()}; | ||||
if (inputs.size() > 1) { | if (inputs.size() > 1) { | ||||
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | ||||
} else { | } else { | ||||
return opr::Reduce::make( | |||||
inputs[0], reduce.param(), (cg::VarNode*)nullptr, config); | |||||
return opr::Reduce::make(inputs[0], reduce.param(), | |||||
(cg::VarNode*)nullptr, config); | |||||
} | } | ||||
} | } | ||||
@@ -135,86 +133,92 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
} | } | ||||
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) | OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) | ||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // reduce | |||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace reduce | |||||
} // namespace | |||||
namespace { namespace adaptive_pooling { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace adaptive_pooling { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& pool = static_cast<const AdaptivePooling&>(def); | auto&& pool = static_cast<const AdaptivePooling&>(def); | ||||
OperatorNodeConfig config{pool.make_name()}; | OperatorNodeConfig config{pool.make_name()}; | ||||
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config); | |||||
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), | |||||
config); | |||||
} | } | ||||
OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // adaptive_pooling | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace adaptive_pooling | |||||
} // namespace | |||||
namespace { namespace conv_bias { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace conv_bias { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& conv = static_cast<const ConvBias&>(def); | auto&& conv = static_cast<const ConvBias&>(def); | ||||
cg::OperatorNodeConfig config{conv.dtype}; | cg::OperatorNodeConfig config{conv.dtype}; | ||||
config.name(conv.make_name()); | config.name(conv.make_name()); | ||||
if (inputs.size() == 2) { | if (inputs.size() == 2) { | ||||
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||||
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), | |||||
conv.policy(), config); | |||||
} else if (inputs.size() == 3) { | } else if (inputs.size() == 3) { | ||||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], | |||||
conv.param(), conv.policy(), config); | |||||
} else if (inputs.size() == 4) { | } else if (inputs.size() == 4) { | ||||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||||
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], | |||||
conv.param(), conv.policy(), config); | |||||
} | } | ||||
mgb_assert(0); | mgb_assert(0); | ||||
} | } | ||||
OP_TRAIT_REG(ConvBias, ConvBias) | OP_TRAIT_REG(ConvBias, ConvBias) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // conv_bias | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace conv_bias | |||||
} // namespace | |||||
namespace { namespace batch_conv_bias { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace batch_conv_bias { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& conv = static_cast<const BatchConvBias&>(def); | auto&& conv = static_cast<const BatchConvBias&>(def); | ||||
cg::OperatorNodeConfig config{conv.dtype}; | cg::OperatorNodeConfig config{conv.dtype}; | ||||
config.name(conv.make_name()); | config.name(conv.make_name()); | ||||
if (inputs.size() == 2) { | if (inputs.size() == 2) { | ||||
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||||
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), | |||||
conv.policy(), config); | |||||
} else if (inputs.size() == 3) { | } else if (inputs.size() == 3) { | ||||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], | |||||
conv.param(), conv.policy(), config); | |||||
} else if (inputs.size() == 4) { | } else if (inputs.size() == 4) { | ||||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||||
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], | |||||
inputs[3], conv.param(), conv.policy(), | |||||
config); | |||||
} | } | ||||
mgb_assert(0); | mgb_assert(0); | ||||
} | } | ||||
OP_TRAIT_REG(BatchConvBias, BatchConvBias) | OP_TRAIT_REG(BatchConvBias, BatchConvBias) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // batch_conv_bias | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace batch_conv_bias | |||||
} // namespace | |||||
namespace { namespace pooling { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace pooling { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& pool = static_cast<const Pooling&>(def); | auto&& pool = static_cast<const Pooling&>(def); | ||||
OperatorNodeConfig config{pool.make_name()}; | OperatorNodeConfig config{pool.make_name()}; | ||||
return opr::Pooling::make(inputs[0], pool.param(), config); | return opr::Pooling::make(inputs[0], pool.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(Pooling, Pooling) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // pooling | |||||
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace pooling | |||||
} // namespace | |||||
namespace { namespace matrix_mul { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace matrix_mul { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& matmul = static_cast<const MatrixMul&>(def); | auto&& matmul = static_cast<const MatrixMul&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{matmul.make_name()}; | OperatorNodeConfig config{matmul.make_name()}; | ||||
@@ -222,14 +226,14 @@ auto apply_on_var_node( | |||||
matmul.policy(), config); | matmul.policy(), config); | ||||
} | } | ||||
OP_TRAIT_REG(MatrixMul, MatrixMul) | OP_TRAIT_REG(MatrixMul, MatrixMul) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // matrix_mul | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace matrix_mul | |||||
} // namespace | |||||
namespace { namespace batched_matrix_mul { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace batched_matrix_mul { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{matmul.make_name()}; | OperatorNodeConfig config{matmul.make_name()}; | ||||
@@ -237,166 +241,155 @@ auto apply_on_var_node( | |||||
matmul.policy(), config); | matmul.policy(), config); | ||||
} | } | ||||
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // batched_matrix_mul | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace batched_matrix_mul | |||||
} // namespace | |||||
namespace { namespace dot { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace dot { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = def.cast_final_safe<Dot>(); | auto&& op = def.cast_final_safe<Dot>(); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::Dot::make(inputs[0], inputs[1], config); | return opr::Dot::make(inputs[0], inputs[1], config); | ||||
} | } | ||||
OP_TRAIT_REG(Dot, Dot) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // dot | |||||
OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace dot | |||||
} // namespace | |||||
namespace { namespace argsort { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace argsort { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& argsort = static_cast<const Argsort&>(def); | auto&& argsort = static_cast<const Argsort&>(def); | ||||
OperatorNodeConfig config{argsort.make_name()}; | OperatorNodeConfig config{argsort.make_name()}; | ||||
return opr::Argsort::make(inputs[0], argsort.param(), config); | return opr::Argsort::make(inputs[0], argsort.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(Argsort, Argsort) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // argsort | |||||
OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace argsort | |||||
} // namespace | |||||
namespace { namespace argmax { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace argmax { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& argmax = static_cast<const Argmax&>(def); | auto&& argmax = static_cast<const Argmax&>(def); | ||||
OperatorNodeConfig config{argmax.make_name()}; | OperatorNodeConfig config{argmax.make_name()}; | ||||
return opr::Argmax::make(inputs[0], argmax.param(), config); | return opr::Argmax::make(inputs[0], argmax.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(Argmax, Argmax) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // argmax | |||||
OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace argmax | |||||
} // namespace | |||||
namespace { namespace argmin { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace argmin { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& argmin = static_cast<const Argmin&>(def); | auto&& argmin = static_cast<const Argmin&>(def); | ||||
OperatorNodeConfig config{argmin.make_name()}; | OperatorNodeConfig config{argmin.make_name()}; | ||||
return opr::Argmin::make(inputs[0], argmin.param(), config); | return opr::Argmin::make(inputs[0], argmin.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(Argmin, Argmin) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // argmin | |||||
OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace argmin | |||||
} // namespace | |||||
namespace { namespace warp_perspective { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace warp_perspective { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& warp = static_cast<const WarpPerspective&>(def); | auto&& warp = static_cast<const WarpPerspective&>(def); | ||||
OperatorNodeConfig config{warp.make_name()}; | OperatorNodeConfig config{warp.make_name()}; | ||||
if (inputs.size() == 3) { | if (inputs.size() == 3) { | ||||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config); | |||||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], | |||||
warp.param(), config); | |||||
} else { | } else { | ||||
mgb_assert(inputs.size() == 4); | mgb_assert(inputs.size() == 4); | ||||
return opr::WarpPerspective::make( | |||||
inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config); | |||||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], | |||||
inputs[3], warp.param(), config); | |||||
} | } | ||||
} | } | ||||
OP_TRAIT_REG(WarpPerspective, WarpPerspective) | OP_TRAIT_REG(WarpPerspective, WarpPerspective) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // warp_perspective | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace warp_perspective | |||||
} // namespace | |||||
namespace { namespace group_local { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace group_local { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& local = static_cast<const GroupLocal&>(def); | auto&& local = static_cast<const GroupLocal&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{local.make_name()}; | OperatorNodeConfig config{local.make_name()}; | ||||
return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); | return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(GroupLocal, GroupLocal) | OP_TRAIT_REG(GroupLocal, GroupLocal) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // group_local | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace group_local | |||||
} // namespace | |||||
namespace { namespace indexing_one_hot { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace indexing_one_hot { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const IndexingOneHot&>(def); | auto&& op = static_cast<const IndexingOneHot&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); | return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // indexing_one_hot | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace indexing_one_hot | |||||
} // namespace | |||||
namespace { namespace indexing_set_one_hot { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace indexing_set_one_hot { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const IndexingSetOneHot&>(def); | auto&& op = static_cast<const IndexingSetOneHot&>(def); | ||||
mgb_assert(inputs.size() == 3); | mgb_assert(inputs.size() == 3); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||||
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], | |||||
op.param(), config); | |||||
} | } | ||||
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // indexing_set_one_hot | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace indexing_set_one_hot | |||||
} // namespace | |||||
namespace { namespace typecvt { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace typecvt { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const TypeCvt&>(def); | auto&& op = static_cast<const TypeCvt&>(def); | ||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::TypeCvt::make(inputs[0], op.dtype, config); | return opr::TypeCvt::make(inputs[0], op.dtype, config); | ||||
} | } | ||||
OP_TRAIT_REG(TypeCvt, TypeCvt) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // typecvt | |||||
OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace typecvt | |||||
} // namespace | |||||
namespace { namespace concat { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace concat { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Concat&>(def); | auto&& op = static_cast<const Concat&>(def); | ||||
cg::OperatorNodeConfig config{op.comp_node}; | cg::OperatorNodeConfig config{op.comp_node}; | ||||
config.name(op.make_name()); | config.name(op.make_name()); | ||||
return opr::Concat::make(inputs, op.axis, config); | return opr::Concat::make(inputs, op.axis, config); | ||||
} | } | ||||
OP_TRAIT_REG(Concat, Concat) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // concat | |||||
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace concat | |||||
} // namespace | |||||
namespace { namespace copy { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace copy { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Copy&>(def); | auto&& op = static_cast<const Copy&>(def); | ||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
cg::OperatorNodeConfig config{op.comp_node}; | cg::OperatorNodeConfig config{op.comp_node}; | ||||
config.name(op.make_name()); | config.name(op.make_name()); | ||||
return opr::Copy::make(inputs[0], config); | return opr::Copy::make(inputs[0], config); | ||||
} | } | ||||
OP_TRAIT_REG(Copy, Copy) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // copy | |||||
OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace copy | |||||
} // namespace | |||||
namespace { namespace assert_equal { | namespace { namespace assert_equal { | ||||
auto apply_on_var_node( | auto apply_on_var_node( | ||||
@@ -408,81 +401,81 @@ auto apply_on_var_node( | |||||
} else { | } else { | ||||
// workaround for MiniGraph, which only allow one opr in the graph | // workaround for MiniGraph, which only allow one opr in the graph | ||||
mgb_assert(inputs.size() == 3); | mgb_assert(inputs.size() == 3); | ||||
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {}); | |||||
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], | |||||
op.param(), {}); | |||||
} | } | ||||
} | } | ||||
OP_TRAIT_REG(AssertEqual, AssertEqual) | OP_TRAIT_REG(AssertEqual, AssertEqual) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // assert_equal | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace assert_equal | |||||
} // namespace | |||||
namespace { namespace roi_align { | |||||
VarNodeArray apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace roi_align { | |||||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const ROIAlign&>(def); | auto&& op = static_cast<const ROIAlign&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
auto* opr = opr::ROIAlign::make( | |||||
inputs[0], inputs[1], op.param(), config).node()->owner_opr(); | |||||
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config) | |||||
.node() | |||||
->owner_opr(); | |||||
return {opr->output(0), opr->output(1)}; | return {opr->output(0), opr->output(1)}; | ||||
} | } | ||||
OP_TRAIT_REG(ROIAlign, ROIAlign) | OP_TRAIT_REG(ROIAlign, ROIAlign) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // roi_align | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace roi_align | |||||
} // namespace | |||||
namespace { namespace correlation { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace correlation { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Correlation&>(def); | auto&& op = static_cast<const Correlation&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::Correlation::make( | |||||
inputs[0], inputs[1], op.param(), config); | |||||
return opr::Correlation::make(inputs[0], inputs[1], op.param(), config); | |||||
} | } | ||||
OP_TRAIT_REG(Correlation, Correlation) | OP_TRAIT_REG(Correlation, Correlation) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // correlation | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace correlation | |||||
} // namespace | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
namespace { namespace nvof { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace nvof { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const NvOf&>(def); | auto&& op = static_cast<const NvOf&>(def); | ||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::NvOf::make(inputs[0], op.param(), config); | return opr::NvOf::make(inputs[0], op.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(NvOf, NvOf) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // nvof | |||||
OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace nvof | |||||
} // namespace | |||||
#endif | #endif | ||||
namespace { namespace linspace { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace linspace { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Linspace&>(def); | auto&& op = static_cast<const Linspace&>(def); | ||||
mgb_assert(inputs.size() == 3); | mgb_assert(inputs.size() == 3); | ||||
cg::OperatorNodeConfig config{op.comp_node}; | cg::OperatorNodeConfig config{op.comp_node}; | ||||
config.name(op.make_name()); | config.name(op.make_name()); | ||||
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||||
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), | |||||
config); | |||||
} | } | ||||
OP_TRAIT_REG(Linspace, Linspace) | OP_TRAIT_REG(Linspace, Linspace) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // linspace | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace linspace | |||||
} // namespace | |||||
namespace { namespace eye { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace eye { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Eye&>(def); | auto&& op = static_cast<const Eye&>(def); | ||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
cg::OperatorNodeConfig config{op.comp_node}; | cg::OperatorNodeConfig config{op.comp_node}; | ||||
@@ -490,58 +483,59 @@ auto apply_on_var_node( | |||||
opr::Eye::Param param{op.k, op.dtype.enumv()}; | opr::Eye::Param param{op.k, op.dtype.enumv()}; | ||||
return opr::Eye::make(inputs[0], param, config); | return opr::Eye::make(inputs[0], param, config); | ||||
} | } | ||||
OP_TRAIT_REG(Eye, Eye) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // eye | |||||
OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace eye | |||||
} // namespace | |||||
namespace { namespace roi_pooling { | |||||
VarNodeArray apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace roi_pooling { | |||||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const ROIPooling&>(def); | auto&& op = static_cast<const ROIPooling&>(def); | ||||
mgb_assert(inputs.size() == 3); | mgb_assert(inputs.size() == 3); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
auto* opr = opr::ROIPooling::make( | |||||
inputs[0], inputs[1], inputs[2], op.param(), config | |||||
).node()->owner_opr(); | |||||
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], | |||||
op.param(), config) | |||||
.node() | |||||
->owner_opr(); | |||||
return {opr->output(0), opr->output(1)}; | return {opr->output(0), opr->output(1)}; | ||||
} | } | ||||
OP_TRAIT_REG(ROIPooling, ROIPooling) | OP_TRAIT_REG(ROIPooling, ROIPooling) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // roi_pooling | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace roi_pooling | |||||
} // namespace | |||||
namespace { namespace remap { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace remap { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Remap&>(def); | auto&& op = static_cast<const Remap&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::Remap::make(inputs[0], inputs[1], op.param(), config); | return opr::Remap::make(inputs[0], inputs[1], op.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(Remap, Remap) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // remap | |||||
OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace remap | |||||
} // namespace | |||||
namespace { | namespace { | ||||
auto get_index( | auto get_index( | ||||
const VarNodeArray& inputs, size_t vidx, | |||||
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) { | |||||
const VarNodeArray& inputs, size_t vidx, | |||||
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) { | |||||
size_t length = mask.size(); | size_t length = mask.size(); | ||||
opr::Subtensor::IndexDesc ret(length); | opr::Subtensor::IndexDesc ret(length); | ||||
for (size_t i = 0; i < length; ++ i) { | |||||
for (size_t i = 0; i < length; ++i) { | |||||
auto&& [axis, begin, end, step, idx] = mask[i]; | auto&& [axis, begin, end, step, idx] = mask[i]; | ||||
ret[i].axis = axis; | ret[i].axis = axis; | ||||
if (idx) { | if (idx) { | ||||
ret[i].idx = inputs[vidx++]; | ret[i].idx = inputs[vidx++]; | ||||
} else { | } else { | ||||
mgb_assert(begin || end || step); | mgb_assert(begin || end || step); | ||||
if (begin) ret[i].begin = inputs[vidx++]; | |||||
if (end) ret[i].end = inputs[vidx++]; | |||||
if (step) ret[i].step = inputs[vidx++]; | |||||
if (begin) | |||||
ret[i].begin = inputs[vidx++]; | |||||
if (end) | |||||
ret[i].end = inputs[vidx++]; | |||||
if (step) | |||||
ret[i].step = inputs[vidx++]; | |||||
} | } | ||||
} | } | ||||
mgb_assert(vidx == inputs.size()); | mgb_assert(vidx == inputs.size()); | ||||
@@ -550,19 +544,19 @@ auto get_index( | |||||
#define IN1 inputs[0] | #define IN1 inputs[0] | ||||
#define IN2 inputs[0], inputs[1] | #define IN2 inputs[0], inputs[1] | ||||
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ | |||||
namespace NAME##_impl { \ | |||||
auto apply_on_var_node( \ | |||||
const OpDef& def, \ | |||||
const VarNodeArray& inputs) { \ | |||||
auto&& op = static_cast<const NAME&>(def); \ | |||||
OperatorNodeConfig config{op.make_name()}; \ | |||||
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \ | |||||
} \ | |||||
OP_TRAIT_REG(NAME, NAME) \ | |||||
.apply_on_var_node(apply_on_var_node) \ | |||||
.fallback(); \ | |||||
} | |||||
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ | |||||
namespace NAME##_impl { \ | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \ | |||||
auto&& op = static_cast<const NAME&>(def); \ | |||||
OperatorNodeConfig config{op.make_name()}; \ | |||||
return opr::NAME::make(IN##NR_INPUT, \ | |||||
get_index(inputs, NR_INPUT, op.items), \ | |||||
config); \ | |||||
} \ | |||||
OP_TRAIT_REG(NAME, NAME) \ | |||||
.apply_on_var_node(apply_on_var_node) \ | |||||
.fallback(); \ | |||||
} | |||||
FANCY_INDEXING_IMPL(Subtensor, 1) | FANCY_INDEXING_IMPL(Subtensor, 1) | ||||
FANCY_INDEXING_IMPL(SetSubtensor, 2) | FANCY_INDEXING_IMPL(SetSubtensor, 2) | ||||
@@ -580,76 +574,88 @@ FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2) | |||||
#undef FANCY_INDEXING_IMPL | #undef FANCY_INDEXING_IMPL | ||||
#undef IN1 | #undef IN1 | ||||
#undef IN2 | #undef IN2 | ||||
} // anonymous namespace | |||||
} // anonymous namespace | |||||
namespace { namespace fake_quant { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace fake_quant { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const FakeQuant&>(def); | auto&& op = static_cast<const FakeQuant&>(def); | ||||
mgb_assert(inputs.size() == 3); | mgb_assert(inputs.size() == 3); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||||
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), | |||||
config); | |||||
} | } | ||||
OP_TRAIT_REG(FakeQuant, FakeQuant) | OP_TRAIT_REG(FakeQuant, FakeQuant) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // fake_quant | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace fake_quant | |||||
} // namespace | |||||
namespace { namespace tqt { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace tqt { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const TQT&>(def); | auto&& op = static_cast<const TQT&>(def); | ||||
mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::TQT::make(inputs[0], inputs[1], op.param(), config); | return opr::TQT::make(inputs[0], inputs[1], op.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(TQT, TQT) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // tqt | |||||
OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace tqt | |||||
} // namespace | |||||
namespace { namespace elemwise_multi_type { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace elemwise_multi_type { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const ElemwiseMultiType&>(def); | auto&& op = static_cast<const ElemwiseMultiType&>(def); | ||||
OperatorNodeConfig config{op.dtype}; | OperatorNodeConfig config{op.dtype}; | ||||
config.name(op.make_name()); | config.name(op.make_name()); | ||||
return opr::ElemwiseMultiType::make(inputs, op.param(), config); | return opr::ElemwiseMultiType::make(inputs, op.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) | OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // elemwise_multi_type | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace elemwise_multi_type | |||||
} // namespace | |||||
namespace { namespace svd { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace svd { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const SVD&>(def); | auto&& op = static_cast<const SVD&>(def); | ||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::SVD::make(inputs[0], op.param(), config)[0] | return opr::SVD::make(inputs[0], op.param(), config)[0] | ||||
.node()->owner_opr()->usable_output(); | |||||
.node() | |||||
->owner_opr() | |||||
->usable_output(); | |||||
} | } | ||||
OP_TRAIT_REG(SVD, SVD) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // svd | |||||
OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace svd | |||||
} // namespace | |||||
namespace { namespace images2neibs { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
namespace { | |||||
namespace images2neibs { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Images2Neibs&>(def); | auto&& op = static_cast<const Images2Neibs&>(def); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::Images2Neibs::make(inputs[0], op.param(), config); | return opr::Images2Neibs::make(inputs[0], op.param(), config); | ||||
} | } | ||||
OP_TRAIT_REG(Images2Neibs, Images2Neibs) | OP_TRAIT_REG(Images2Neibs, Images2Neibs) | ||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // images2neibs | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace images2neibs | |||||
} // namespace | |||||
namespace { | |||||
namespace lsq { | |||||
//! Lowers an imperative LSQ op definition to its graph operator. | |||||
//! Asserts that exactly four inputs are given and passes them, in order, to | |||||
//! opr::LSQ::make together with op.param() and a config built from | |||||
//! op.make_name(). | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const LSQ&>(def); | |||||
mgb_assert(inputs.size() == 4); | |||||
OperatorNodeConfig config{op.make_name()}; | |||||
return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3], | |||||
op.param(), config); | |||||
} | |||||
//! Registers apply_on_var_node in the LSQ op trait. .fallback() presumably | |||||
//! fills the remaining trait entries with defaults -- confirm in OP_TRAIT_REG. | |||||
OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace lsq | |||||
} // namespace | |||||
} // namespace mgb::imperative | |||||
} // namespace mgb::imperative |
@@ -6,22 +6,24 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "./helper.h" | #include "./helper.h" | ||||
#include "megbrain/imperative/backward_graph_opt.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/imperative/ops/opr_attr.h" | |||||
#include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
#include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
#include "megbrain/imperative/ops/opr_attr.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/imperative/backward_graph_opt.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace cg; | using namespace cg; | ||||
using namespace imperative; | using namespace imperative; | ||||
template <typename T> | template <typename T> | ||||
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) { | |||||
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, | |||||
const T& outputs, const T& grads) { | |||||
T ret; | T ret; | ||||
size_t i = 0; | size_t i = 0; | ||||
for (auto&& t : inputs) { | for (auto&& t : inputs) { | ||||
@@ -54,7 +56,9 @@ T expand_grads(const U& bg, const T& outputs) { | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) { | |||||
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, | |||||
const T& precomp, const T& inputs, | |||||
const T& outputs, const T& grads) { | |||||
T ret = precomp; | T ret = precomp; | ||||
size_t i = 0; | size_t i = 0; | ||||
for (auto&& t : inputs) { | for (auto&& t : inputs) { | ||||
@@ -75,7 +79,8 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons | |||||
return ret; | return ret; | ||||
} | } | ||||
SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||||
SmallVector<TensorPtr> apply_shared_on_physical_tensor( | |||||
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||||
return OpDef::apply_on_physical_tensor(*def, inputs); | return OpDef::apply_on_physical_tensor(*def, inputs); | ||||
} | } | ||||
@@ -83,7 +88,7 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
SmallVector<HostTensorND> hvs; | SmallVector<HostTensorND> hvs; | ||||
SmallVector<TensorPtr> inputs; | SmallVector<TensorPtr> inputs; | ||||
for(size_t i = 0; i < 2; ++ i) { | |||||
for (size_t i = 0; i < 2; ++i) { | |||||
hvs.push_back(*gen({42})); | hvs.push_back(*gen({42})); | ||||
inputs.push_back(Tensor::make(hvs.back())); | inputs.push_back(Tensor::make(hvs.back())); | ||||
} | } | ||||
@@ -97,7 +102,8 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
for (auto&& i : inputs) { | for (auto&& i : inputs) { | ||||
input_descs.push_back({i->layout(), i->comp_node()}); | input_descs.push_back({i->layout(), i->comp_node()}); | ||||
} | } | ||||
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true}); | |||||
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, | |||||
{true}); | |||||
auto&& save_for_backward = result.save_for_backward; | auto&& save_for_backward = result.save_for_backward; | ||||
auto&& input_has_grad = result.input_has_grad; | auto&& input_has_grad = result.input_has_grad; | ||||
@@ -106,9 +112,9 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
hvs.push_back(*gen({42})); | hvs.push_back(*gen({42})); | ||||
inputs.push_back(Tensor::make(hvs.back())); | inputs.push_back(Tensor::make(hvs.back())); | ||||
mgb_assert(save_for_backward.size() == inputs.size()); | mgb_assert(save_for_backward.size() == inputs.size()); | ||||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (!save_for_backward[i]) { | if (!save_for_backward[i]) { | ||||
inputs[i].reset(); // drop unused tensor | |||||
inputs[i].reset(); // drop unused tensor | |||||
} | } | ||||
} | } | ||||
SmallVector<TensorPtr> backward_graph_inputs; | SmallVector<TensorPtr> backward_graph_inputs; | ||||
@@ -118,13 +124,11 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
} | } | ||||
} | } | ||||
inputs.clear(); | inputs.clear(); | ||||
auto input_grads = result.backward.apply( | |||||
backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
); | |||||
auto input_grads = result.backward.apply(backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x) { return x; }); | |||||
mgb_assert(input_grads.size() == input_has_grad.size()); | mgb_assert(input_grads.size() == input_has_grad.size()); | ||||
for (size_t i = 0; i < input_has_grad.size(); ++ i) { | |||||
for (size_t i = 0; i < input_has_grad.size(); ++i) { | |||||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | ||||
} | } | ||||
@@ -133,9 +137,10 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
res.emplace_back(); | res.emplace_back(); | ||||
res.back().copy_from(i->dev_tensor()).sync(); | res.back().copy_from(i->dev_tensor()).sync(); | ||||
} | } | ||||
for (size_t i = 0; i < 42; ++ i) { | |||||
for (size_t j = 0; j < 1; ++ j) { | |||||
ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i], res[j ^ 1].ptr<float>()[i]); | |||||
for (size_t i = 0; i < 42; ++i) { | |||||
for (size_t j = 0; j < 1; ++j) { | |||||
ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i], | |||||
res[j ^ 1].ptr<float>()[i]); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -152,7 +157,8 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
SmallVector<LogicalTensorDesc> input_descs; | SmallVector<LogicalTensorDesc> input_descs; | ||||
input_descs.push_back({a->layout(), a->comp_node()}); | input_descs.push_back({a->layout(), a->comp_node()}); | ||||
auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||||
auto result = | |||||
OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||||
auto&& save_for_backward = result.save_for_backward; | auto&& save_for_backward = result.save_for_backward; | ||||
auto&& input_has_grad = result.input_has_grad; | auto&& input_has_grad = result.input_has_grad; | ||||
@@ -160,9 +166,9 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
inputs.push_back(outputs[0]); | inputs.push_back(outputs[0]); | ||||
inputs.push_back(dc); | inputs.push_back(dc); | ||||
mgb_assert(save_for_backward.size() == inputs.size()); | mgb_assert(save_for_backward.size() == inputs.size()); | ||||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (!save_for_backward[i]) { | if (!save_for_backward[i]) { | ||||
inputs[i].reset(); // drop unused tensor | |||||
inputs[i].reset(); // drop unused tensor | |||||
} | } | ||||
} | } | ||||
SmallVector<TensorPtr> backward_graph_inputs; | SmallVector<TensorPtr> backward_graph_inputs; | ||||
@@ -172,19 +178,17 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
} | } | ||||
} | } | ||||
inputs.clear(); | inputs.clear(); | ||||
auto input_grads = result.backward.apply( | |||||
backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
); | |||||
auto input_grads = result.backward.apply(backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x) { return x; }); | |||||
mgb_assert(input_grads.size() == input_has_grad.size()); | mgb_assert(input_grads.size() == input_has_grad.size()); | ||||
for (size_t i = 0; i < input_has_grad.size(); ++ i) { | |||||
for (size_t i = 0; i < input_has_grad.size(); ++i) { | |||||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | ||||
} | } | ||||
HostTensorND hv; | HostTensorND hv; | ||||
hv.copy_from(input_grads[0]->dev_tensor()).sync(); | hv.copy_from(input_grads[0]->dev_tensor()).sync(); | ||||
for (size_t i = 0; i < 42; ++ i) { | |||||
for (size_t i = 0; i < 42; ++i) { | |||||
ASSERT_EQ(host_dc->ptr<float>()[i], hv.ptr<float>()[i]); | ASSERT_EQ(host_dc->ptr<float>()[i], hv.ptr<float>()[i]); | ||||
} | } | ||||
} | } | ||||
@@ -192,7 +196,7 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
TEST(TestImperative, BatchNormGrad) { | TEST(TestImperative, BatchNormGrad) { | ||||
auto cn = CompNode::load("xpux"); | auto cn = CompNode::load("xpux"); | ||||
using Param = opr::BatchNorm::Param; | using Param = opr::BatchNorm::Param; | ||||
size_t N=2, C=3, H=5, W=5; | |||||
size_t N = 2, C = 3, H = 5, W = 5; | |||||
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | ||||
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | ||||
{ | { | ||||
@@ -202,7 +206,8 @@ TEST(TestImperative, BatchNormGrad) { | |||||
param.fwd_mode = Param::FwdMode::TRAINING; | param.fwd_mode = Param::FwdMode::TRAINING; | ||||
attr.param.write_pod(param); | attr.param.write_pod(param); | ||||
OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | ||||
{true, true ,true, false, false}, {false, false, false, false, true}); | |||||
{true, true, true, false, false}, | |||||
{false, false, false, false, true}); | |||||
} | } | ||||
{ | { | ||||
auto op = OprAttr::make("BatchNorm"); | auto op = OprAttr::make("BatchNorm"); | ||||
@@ -210,8 +215,8 @@ TEST(TestImperative, BatchNormGrad) { | |||||
Param param; | Param param; | ||||
param.fwd_mode = Param::FwdMode::TRAINING; | param.fwd_mode = Param::FwdMode::TRAINING; | ||||
attr.param.write_pod(param); | attr.param.write_pod(param); | ||||
OpDef::make_backward_graph(attr, {inp, stat, stat}, | |||||
{true, true ,true}, {false, false, true}); | |||||
OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true}, | |||||
{false, false, true}); | |||||
} | } | ||||
} | } | ||||
@@ -220,7 +225,8 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||||
LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn}; | LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn}; | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD)); | auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD)); | ||||
auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); | |||||
auto bg = | |||||
OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); | |||||
auto obg = OptimizedBackwardGraphResult(bg); | auto obg = OptimizedBackwardGraphResult(bg); | ||||
ASSERT_EQ(obg.save_for_backward.size(), 4); | ASSERT_EQ(obg.save_for_backward.size(), 4); | ||||
ASSERT_FALSE(obg.save_for_backward[0]); | ASSERT_FALSE(obg.save_for_backward[0]); | ||||
@@ -235,30 +241,30 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||||
auto dc_tn = Tensor::make(*dc_hv); | auto dc_tn = Tensor::make(*dc_hv); | ||||
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | ||||
auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||||
auto grads = expand_grads(bg, bg.backward.apply( | |||||
backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
)); | |||||
auto backward_graph_inputs = | |||||
prepare_backward_graph_inputs<SmallVector<TensorPtr>>( | |||||
bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||||
auto grads = | |||||
expand_grads(bg, bg.backward.apply(backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x) { return x; })); | |||||
auto precomp = obg.precomp.apply( | |||||
SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
); | |||||
auto precomp = obg.precomp.apply(SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x) { return x; }); | |||||
ASSERT_EQ(precomp.size(), 2); | ASSERT_EQ(precomp.size(), 2); | ||||
ASSERT_EQ(precomp[0]->shape().ndim, 1); | ASSERT_EQ(precomp[0]->shape().ndim, 1); | ||||
ASSERT_LE(precomp[0]->shape()[0], 2); | ASSERT_LE(precomp[0]->shape()[0], 2); | ||||
ASSERT_EQ(precomp[1]->shape().ndim, 1); | ASSERT_EQ(precomp[1]->shape().ndim, 1); | ||||
ASSERT_LE(precomp[1]->shape()[0], 2); | ASSERT_LE(precomp[1]->shape()[0], 2); | ||||
auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||||
auto grads2 = expand_grads(obg, obg.backward.apply( | |||||
backward_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
)); | |||||
auto backward_inputs = | |||||
prepare_optimized_backward_inputs<SmallVector<TensorPtr>>( | |||||
obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||||
auto grads2 = expand_grads( | |||||
obg, | |||||
obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor, | |||||
[&](auto&& x) { return x; })); | |||||
ASSERT_EQ(grads2.size(), 2); | ASSERT_EQ(grads2.size(), 2); | ||||
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | ||||
@@ -271,6 +271,7 @@ def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; | |||||
def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; | def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; | ||||
def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>; | def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>; | ||||
def TQT: MgbHashableOp<"TQT", [TQTParam]>; | def TQT: MgbHashableOp<"TQT", [TQTParam]>; | ||||
def LSQ: MgbHashableOp<"LSQ", [LSQParam]>; | |||||
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { | def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { | ||||
let extraArguments = (ins | let extraArguments = (ins | ||||
MgbDTypeAttr:$dtype | MgbDTypeAttr:$dtype | ||||
@@ -324,5 +324,7 @@ decl_opr('FakeQuant', | |||||
decl_opr('TQT', | decl_opr('TQT', | ||||
inputs=[Doc('src','input tensor'),Doc('scale','scale tensor')], | inputs=[Doc('src','input tensor'),Doc('scale','scale tensor')], | ||||
params='TQT') | params='TQT') | ||||
decl_opr('LSQ', | |||||
inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor'),Doc('grad_scale','grad scale tensor')], | |||||
params='LSQ') | |||||
# vim: ft=python | # vim: ft=python |
@@ -6,20 +6,22 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
#include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/dnn/correlation.h" | #include "megbrain/opr/dnn/correlation.h" | ||||
#include "megbrain/opr/dnn/fake_quant.h" | |||||
#include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | |||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
#include "megbrain/opr/dnn/roi_pooling.h" | |||||
#include "megbrain/opr/dnn/roi_align.h" | |||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
#include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
#include "megbrain/opr/dnn/fake_quant.h" | |||||
#include "megbrain/opr/dnn/lsq.h" | |||||
#include "megbrain/opr/dnn/pooling.h" | |||||
#include "megbrain/opr/dnn/roi_align.h" | |||||
#include "megbrain/opr/dnn/roi_pooling.h" | |||||
#include "megbrain/opr/dnn/tqt.h" | #include "megbrain/opr/dnn/tqt.h" | ||||
#include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
@@ -183,7 +185,8 @@ struct ConvLoadDumpImpl { | |||||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | ||||
auto&& opr = opr_.cast_final_safe<Opr>(); | auto&& opr = opr_.cast_final_safe<Opr>(); | ||||
ctx.write_param<ConvParam>(opr.param()); | ctx.write_param<ConvParam>(opr.param()); | ||||
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy_transient()); | |||||
ctx.write_param<megdnn::param::ExecutionPolicy>( | |||||
opr.execution_policy_transient()); | |||||
} | } | ||||
static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param, | static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param, | ||||
@@ -252,6 +255,20 @@ struct OprMaker<opr::TQTBackward, 3> { | |||||
}; | }; | ||||
//! OprMaker specialization for the five-input opr::LSQBackward. | |||||
//! make() forwards i[0]..i[4] plus param and config to | |||||
//! opr::LSQBackward::make and returns the operator that owns the first | |||||
//! resulting output var. | |||||
template <> | template <> | ||||
struct OprMaker<opr::LSQBackward, 5> { | |||||
using Param = opr::LSQBackward::Param; | |||||
static cg::OperatorNodeBase* make(const Param& param, | |||||
const cg::VarNodeArray& i, | |||||
ComputingGraph& graph, | |||||
const OperatorNodeConfig& config) { | |||||
// graph is not needed here; the macro marks it as used. | |||||
MGB_MARK_USED_VAR(graph); | |||||
return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param, | |||||
config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} | |||||
}; | |||||
template <> | |||||
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | ||||
: public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward, | : public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward, | ||||
MakeAdaptivePoolingBackwardCaller3< | MakeAdaptivePoolingBackwardCaller3< | ||||
@@ -587,6 +604,8 @@ MGB_SEREG_OPR(FakeQuant, 3); | |||||
MGB_SEREG_OPR(FakeQuantBackward, 4); | MGB_SEREG_OPR(FakeQuantBackward, 4); | ||||
MGB_SEREG_OPR(TQT, 2); | MGB_SEREG_OPR(TQT, 2); | ||||
MGB_SEREG_OPR(TQTBackward, 3); | MGB_SEREG_OPR(TQTBackward, 3); | ||||
MGB_SEREG_OPR(LSQ, 4); | |||||
MGB_SEREG_OPR(LSQBackward, 5); | |||||
} // namespace opr | } // namespace opr | ||||
@@ -0,0 +1,90 @@ | |||||
/** | |||||
* \file src/opr/impl/dnn/lsq.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/lsq.h" | |||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
#include "megbrain/graph/grad_impl.h" | |||||
#include "megbrain/opr/basic_arith_wrapper.h" | |||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/opr/utility.h" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
//! Dynamic-type registration for LSQForward, followed by the | |||||
//! MEGDNN_OPR_INIT4 expansion for it with the name "lsq_fwd". | |||||
//! NOTE(review): presumably this expansion supplies the four-input | |||||
//! constructor and make() declared in lsq.h -- confirm in the macro. | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQForward); | |||||
MEGDNN_OPR_INIT4(LSQForward, "lsq_fwd"); | |||||
#ifdef MGB_ENABLE_GRAD | |||||
//! Gradient rule for LSQForward. | |||||
//! One LSQBackward is built from out_grad[0], the four forward inputs and | |||||
//! the forward param. The value returned depends on wrt_idx: | |||||
//!   0 -> grad[0] as is; | |||||
//!   1 -> grad[1] summed down to the shape of forward input 1; | |||||
//!   otherwise -> nullptr. | |||||
MGB_IMPL_OPR_GRAD(LSQForward) { | |||||
SymbolVarArray grad = | |||||
LSQBackward::make(out_grad[0], opr.input(0), opr.input(1), | |||||
opr.input(2), opr.input(3), opr.param()); | |||||
if (wrt_idx == 0) { | |||||
return grad[0].node(); | |||||
} else if (wrt_idx == 1) { | |||||
// grad[1] is inferred with the shape of x (see | |||||
// LSQBackward::init_output_static_infer_desc), so it is reduced to the | |||||
// shape of the input being differentiated. | |||||
return reduce_sum(grad[1], GetVarShape::make(opr.input(wrt_idx))) | |||||
.node(); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
#endif | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQBackward); | |||||
//! Constructs the backward operator. | |||||
//! The owner graph is taken from x, the name passed to Super is "lsq_bwd", | |||||
//! and the five vars are registered as inputs in the order | |||||
//! y_grad, x, scale, zero_point, grad_scale. | |||||
//! NOTE(review): the trailing Super arguments (1, true) are forwarded to | |||||
//! MegDNNOprWrapperBwd; their meaning is not visible in this file. | |||||
LSQBackward::LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, | |||||
VarNode* zero_point, VarNode* grad_scale, | |||||
const Param& param, const OperatorNodeConfig& config) | |||||
: Super({x->owner_graph(), | |||||
config, | |||||
"lsq_bwd", | |||||
{y_grad, x, scale, zero_point, grad_scale}}, | |||||
1, true) { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({y_grad, x, scale, zero_point, grad_scale}); | |||||
} | |||||
//! Inserts a new LSQBackward operator into the owner graph of x and returns | |||||
//! all of its outputs, in order, as a SymbolVarArray. | |||||
SymbolVarArray LSQBackward::make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, | |||||
SymbolVar zero_point, SymbolVar grad_scale, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
auto&& out = x.node()->owner_graph() | |||||
->insert_opr(std::make_unique<LSQBackward>( | |||||
y_grad.node(), x.node(), scale.node(), | |||||
zero_point.node(), grad_scale.node(), param, | |||||
config)) | |||||
->output(); | |||||
// Copy every output var into the returned array element by element. | |||||
SymbolVarArray ret(out.size()); | |||||
for (size_t i = 0; i < ret.size(); ++i) { | |||||
ret[i] = out[i]; | |||||
} | |||||
return ret; | |||||
} | |||||
//! Registers static shape inference for the outputs. | |||||
//! Both output(0) and output(1) take the shape of input(1), which is x in | |||||
//! the constructor's input order. Workspace inference is then registered | |||||
//! through init_output_static_infer_desc_workspace. | |||||
void LSQBackward::init_output_static_infer_desc() { | |||||
using namespace cg::static_infer; | |||||
auto&& mgr = owner_graph()->static_infer_manager(); | |||||
mgr.register_shape_infer(output(0), | |||||
ShapeInferDesc::make_identity(input(1))); | |||||
// output(1) also uses the shape of x rather than of scale; the forward | |||||
// gradient rule reduces it afterwards. | |||||
mgr.register_shape_infer(output(1), | |||||
ShapeInferDesc::make_identity(input(1))); | |||||
this->init_output_static_infer_desc_workspace( | |||||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSQBackward>::val); | |||||
} | |||||
//! Sets output dtypes: output(0) copies the dtype of input(1) (x) and | |||||
//! output(1) copies the dtype of input(2) (scale). | |||||
void LSQBackward::init_output_dtype() { | |||||
output(0)->dtype(input(1)->dtype()); | |||||
output(1)->dtype(input(2)->dtype()); | |||||
} | |||||
@@ -0,0 +1,50 @@ | |||||
/** | |||||
* \file src/opr/include/megbrain/opr/dnn/lsq.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megdnn/oprs.h" | |||||
namespace mgb { | |||||
namespace opr { | |||||
//! Graph operator declared on top of | |||||
//! intl::MegDNNOprWrapperFwd<megdnn::LSQForward>. | |||||
//! Inputs, in order: src, scale, zero_point, grad_scale. | |||||
//! make() returns a single SymbolVar; param and config default to {}. | |||||
MGB_DEFINE_OPR_CLASS(LSQForward, | |||||
intl::MegDNNOprWrapperFwd<megdnn::LSQForward>) // { | |||||
public: | |||||
LSQForward(VarNode* src, VarNode* scale, VarNode* zero_point, | |||||
VarNode* grad_scale, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point, | |||||
SymbolVar grad_scale, const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
}; | |||||
using LSQ = LSQForward; | |||||
//! Graph operator declared on top of | |||||
//! intl::MegDNNOprWrapperBwd<megdnn::LSQBackward>. | |||||
//! Inputs, in order: y_grad, x, scale, zero_point, grad_scale. | |||||
//! make() returns a SymbolVarArray; param and config default to {}. | |||||
//! Output shape and dtype setup is overridden by the two private hooks. | |||||
MGB_DEFINE_OPR_CLASS(LSQBackward, | |||||
intl::MegDNNOprWrapperBwd<megdnn::LSQBackward>) // { | |||||
public: | |||||
LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, VarNode* zero_point, | |||||
VarNode* grad_scale, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVarArray make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, | |||||
SymbolVar zero_point, SymbolVar grad_scale, | |||||
const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
private: | |||||
void init_output_static_infer_desc() override; | |||||
void init_output_dtype() override; | |||||
}; | |||||
} // namespace opr | |||||
} // namespace mgb |
@@ -0,0 +1,78 @@ | |||||
/** | |||||
* \file src/opr/test/dnn/lsq.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/lsq.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/test/autocheck.h" | |||||
using namespace std; | |||||
using namespace mgb; | |||||
namespace { | |||||
//! Forward-only check of opr::LSQForward with AutoOprChecker<4, 1>. | |||||
//! The graph output is compared against a megdnn::LSQForward executed | |||||
//! directly on the handle of the default CPU comp node. Gradient checking | |||||
//! is switched off for every input and for the output. | |||||
void run() { | |||||
using Checker = AutoOprChecker<4, 1>; | |||||
// Graph under test: a single LSQForward fed by the four checker inputs. | |||||
auto make_graph = | |||||
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
auto o0 = opr::LSQForward::make(inputs[0], inputs[1], inputs[2], | |||||
inputs[3]); | |||||
return {o0}; | |||||
}; | |||||
// Reference: run the megdnn operator into a Float32 tensor that has the | |||||
// shape and comp node of inp[0]; an empty workspace ({}) is passed. | |||||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
auto opr = MegDNNHandle::get( | |||||
CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||||
->create_operator<megdnn::LSQForward>(); | |||||
dest[0].dtype(dtype::Float32()) | |||||
.comp_node(inp[0]->comp_node()) | |||||
.resize(inp[0]->shape()); | |||||
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||||
inp[3]->as_megdnn(), dest[0].as_megdnn(), {}); | |||||
}; | |||||
// Shared input generator: Float32 values from a GAUSSIAN | |||||
// HostTensorGenerator constructed with 10.f, keeping the tensor's | |||||
// existing shape and comp node. | |||||
auto gen = [&](HostTensorND& src) { | |||||
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> | |||||
src_gen(10.f); | |||||
src = *src_gen(src.shape(), src.comp_node()); | |||||
}; | |||||
Checker::RunOptions opt; | |||||
// NOTE(review): all gradient checks are disabled below, so this numdiff | |||||
// tolerance presumably has no effect -- confirm. | |||||
opt.numdiff_max_err = 1e-5; | |||||
Checker checker{make_graph, fwd}; | |||||
checker.set_input_generator(0, gen) | |||||
.set_input_generator(1, gen) | |||||
.set_input_generator(2, gen) | |||||
.set_input_generator(3, gen) | |||||
.set_input_allow_grad(0, false) | |||||
.set_input_allow_grad(1, false) | |||||
.set_input_allow_grad(2, false) | |||||
.set_input_allow_grad(3, false) | |||||
.set_output_allow_grad(0, false); | |||||
// Three 4-d source shapes; scale, zero_point and grad_scale are always | |||||
// of shape {1}. | |||||
checker.run({TensorShape{1, 2, 3, 4}, TensorShape{1}, TensorShape{1}, | |||||
TensorShape{1}}, | |||||
opt) | |||||
.run({TensorShape{2, 3, 8, 8}, TensorShape{1}, TensorShape{1}, | |||||
TensorShape{1}}, | |||||
opt) | |||||
.run({TensorShape{1, 3, 4, 4}, TensorShape{1}, TensorShape{1}, | |||||
TensorShape{1}}, | |||||
opt); | |||||
} | |||||
} // anonymous namespace | |||||
//! Runs the forward check above behind REQUIRE_GPU(1). | |||||
//! NOTE(review): presumably the test is skipped when no GPU is present -- | |||||
//! confirm in REQUIRE_GPU. | |||||
TEST(TestOprDNN, LSQForward) { | |||||
REQUIRE_GPU(1); | |||||
run(); | |||||
} | |||||
@@ -107,6 +107,7 @@ union OperatorParam { | |||||
param.FakeQuant = 73, | param.FakeQuant = 73, | ||||
param.TQT = 74, | param.TQT = 74, | ||||
param.Correlation = 75, | param.Correlation = 75, | ||||
param.LSQ = 76, | |||||
} | } | ||||
table Operator { | table Operator { | ||||