@@ -2048,6 +2048,53 @@ protected:
             const TensorLayout& doup, const TensorLayout& mask,
             const TensorLayout& dinp, size_t workspace_in_bytes);
 };
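+
+/*!
+ * \brief base class for softmax forward and backward; carries the Softmax
+ *        param, i.e. the axis along which normalization is performed
+ *        (-1 meaning the last axis)
+ */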
+class SoftmaxBase : public OperatorBase {
+    DEF_OPR_IMPL_CTOR(SoftmaxBase, OperatorBase);
+    DEF_OPR_PARAM(Softmax);
+
+protected:
+    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
+    void check_layout_fwd(const TensorLayout& input, const TensorLayout& output);
+};
+
+class SoftmaxForward : public SoftmaxBase {
+    DEF_OPR_IMPL(SoftmaxForward, SoftmaxBase, 1, 1);
+
+public:
+    /**
+     * \param[in] input input tensor
+     * \param[out] output output tensor
+     */
+    virtual void exec(
+            _megdnn_tensor_in input, _megdnn_tensor_out output,
+            _megdnn_workspace workspace) = 0;
+    void deduce_layout(const TensorLayout& input, TensorLayout& output);
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& input, const TensorLayout& output) = 0;
+
+protected:
+    void check_exec(
+            const TensorLayout& input, const TensorLayout& output,
+            size_t workspace_in_bytes);
+};
+using Softmax = SoftmaxForward;
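+
+/*!
+ * \brief computes the gradient of softmax
+ *
+ * Note: as wired up by the graph gradient rule in src/opr/impl/dnn/softmax.cpp,
+ * \p input receives the *output* of the forward pass (y), \p diff the gradient
+ * w.r.t. y, and \p grad_x the gradient w.r.t. x, i.e.
+ * grad_x = y * (diff - sum(y * diff, axis)).
+ */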
+class SoftmaxBackward : public SoftmaxBase {
+    DEF_OPR_IMPL(SoftmaxBackward, SoftmaxBase, 2, 1);
+
+public:
+    virtual void exec(
+            _megdnn_tensor_in input, _megdnn_tensor_in diff,
+            _megdnn_tensor_out grad_x, _megdnn_workspace workspace) = 0;
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& input, const TensorLayout& diff,
+            const TensorLayout& grad_x) = 0;
+
+protected:
+    void check_exec(
+            const TensorLayout& input, const TensorLayout& diff,
+            const TensorLayout& grad_x, size_t workspace_in_bytes);
+};
+
 class RNNCellForward : public OperatorBase {
     DEF_OPR_PARAM(RNNCell);

@@ -253,6 +253,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
  add_enum_alias('Format', 'Convolution')
  )
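+
+# axis: the dimension along which softmax is computed; -1 means the last axis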
+(pdef('Softmax').
+ add_fields('int32', 'axis', -1)
+ )
 
 (pdef('AdaptivePooling', version=0, is_legacy=True).
  add_enum_alias('Mode', 'PoolingV0').
  add_enum_alias('Format', 'ConvolutionV0')

@@ -219,7 +219,9 @@ private:
     cb(RNN) \
     cb(RNNBackward) \
     cb(LSTM) \
-    cb(LSTMBackward)
+    cb(LSTMBackward) \
+    cb(SoftmaxForward) \
+    cb(SoftmaxBackward)
 // clang-format on
 
 /*!

@@ -145,6 +145,8 @@ DEF(RNNBackward, 10, true, true);
 DEF(LSTMCellForward, 10, true, true);
 DEF(LSTMForward, 8, true, true);
 DEF(LSTMBackward, 13, true, true);
+DEF(SoftmaxForward, 2, true, true);
+DEF(SoftmaxBackward, 3, true, false);
 }  // namespace megdnn
 
 // vim: syntax=cpp.doxygen

@@ -0,0 +1,61 @@
+/**
+ * \file dnn/src/common/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "megdnn/oprs.h"
+
+#include "src/common/utils.h"
+
+namespace megdnn {
+
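+// Negative axis values in [-ndim, 0) count from the last dimension, matching
+// the python-level convention.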
+void SoftmaxBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
+    megdnn_assert(
+            param().axis >= -static_cast<int32_t>(src.ndim) &&
+                    param().axis < static_cast<int32_t>(src.ndim),
+            "axis: %d ndim: %zu", param().axis, src.ndim);
+    megdnn_assert_contiguous(src);
+    dst = src;
+    dst.dtype = src.dtype;
+    dst.format = src.format;
+    dst.init_contiguous_stride();
+}
+
+void SoftmaxBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
+    TensorLayout dst_expected;
+    megdnn_assert_eq_dtype(src, dst);
+    deduce_layout_fwd(src, dst_expected);
+    megdnn_assert_eq_layout(dst_expected, dst);
+    megdnn_assert(src.dtype == dst.dtype);
+}
+
+void SoftmaxForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
+    deduce_layout_fwd(src, dst);
+}
+
+void SoftmaxForward::check_exec(
+        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
+    check_layout_fwd(src, dst);
+    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
+void SoftmaxBackward::check_exec(
+        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
+        size_t workspace_in_bytes) {
+    megdnn_assert_eq_layout(src, diff);
+    megdnn_assert_eq_layout(src, grad);
+    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen

@@ -76,6 +76,7 @@
 #include "src/cuda/separable_filter/opr_impl.h"
 #include "src/cuda/sleep/opr_impl.h"
 #include "src/cuda/sliding_window_transpose/opr_impl.h"
+#include "src/cuda/softmax/opr_impl.h"
 #include "src/cuda/split/opr_impl.h"
 #include "src/cuda/svd/opr_impl.h"
 #include "src/cuda/tensor_remap/opr_impl.h"
@@ -221,6 +222,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxBackward);
 
 template <typename Opr>
 std::unique_ptr<Opr> HandleImpl::create_operator() {

@@ -0,0 +1,174 @@
+/**
+ * \file dnn/src/cuda/softmax/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "src/cuda/softmax/opr_impl.h"
+#include "src/cuda/handle.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+
+int CanonicalAxis(const int axis, const int rank) {
+    if (axis < 0) {
+        return axis + rank;
+    }
+    return axis;
+}
+
+int SizeToAxis(const int axis, const size_t* dims) {
+    int size = 1;
+    for (int i = 0; i < axis; i++) {
+        size *= dims[i];
+    }
+    return size;
+}
+
+int SizeOutAxis(const int axis, const size_t* dims, const int ndim) {
+    int size = 1;
+    for (int i = axis + 1; i < ndim; i++) {
+        size *= dims[i];
+    }
+    return size;
+}
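+
+// Collapse the input layout around the softmax axis into a 4-d cuDNN tensor
+// {N, dim, D, 1}: e.g. shape (2, 3, 4, 5) with axis=1 gives N=2, dim=3 and
+// D=20; CUDNN_SOFTMAX_MODE_CHANNEL then normalizes over `dim`.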
+std::vector<int> SoftmaxForwardImpl::init_mode(
+        _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const {
+    auto dims = src.layout.shape;
+    const int rank = src.layout.ndim;
+    const int axis = CanonicalAxis(param().axis, rank);
+    const int dim = dims[axis];
+    const int N = SizeToAxis(axis, dims);
+    const int D = SizeOutAxis(axis, dims, rank);
+
+    mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL;
+    return {N, dim, D, 1};
+}
+
+int sc(const size_t x) {
+    return static_cast<int>(x);
+}
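+
+// Maps megdnn dtypes to cuDNN data types; the format argument only matters
+// for the vectorized int8 layouts (NCHW4 / NCHW32).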
+cudnnDataType_t to_cudnn_dtype(
+        DType type, const param::Convolution::Format format = {}) {
+    switch (type.enumv()) {
+        case DTypeEnum::Float32:
+            return CUDNN_DATA_FLOAT;
+        case DTypeEnum::Float16:
+            return CUDNN_DATA_HALF;
+#if CUDNN_MAJOR >= 7
+        case DTypeEnum::Int32:
+        case DTypeEnum::QuantizedS32:
+            return CUDNN_DATA_INT32;
+#endif
+#if CUDNN_MAJOR >= 6
+        case DTypeEnum::QuantizedS8: {
+            if (format == param::Convolution::Format::NCHW4)
+                return CUDNN_DATA_INT8x4;
+#if CUDNN_VERSION >= 7500
+            else if (format == param::Convolution::Format::NCHW32)
+                return CUDNN_DATA_INT8x32;
+#endif
+            else
+                return CUDNN_DATA_INT8;
+        }
+        case DTypeEnum::Int8: {
+            if (format == param::Convolution::Format::NCHW4)
+                return CUDNN_DATA_INT8x4;
+#if CUDNN_VERSION >= 7500
+            else if (format == param::Convolution::Format::NCHW32)
+                return CUDNN_DATA_INT8x32;
+#endif
+            else
+                return CUDNN_DATA_INT8;
+        }
+#endif
+        default:
+#if CUDNN_MAJOR >= 6
+            megdnn_throw("dtype must be float16/float32/int8/int32");
+#else
+            megdnn_throw("dtype must be float16/float32");
+#endif
+    }
+}
+
+void SoftmaxForwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    dt_float32 alpha = 1.0f, beta = 0.0f;
+    TensorDesc src_desc, dst_desc;
+    cudnnSoftmaxMode_t mode;
+    std::vector<int> tensor_dims = init_mode(src, mode);
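+    // dense (fully packed) strides for the collapsed {N, dim, D, 1} view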
+    const int dimA[] = {
+            sc(tensor_dims[0]), sc(tensor_dims[1]), sc(tensor_dims[2]),
+            sc(tensor_dims[3])};
+    const int strideA[] = {
+            sc(tensor_dims[1] * tensor_dims[2] * tensor_dims[3]),
+            sc(tensor_dims[2] * tensor_dims[3]), sc(tensor_dims[3]), 1};
+
+    cudnn_check(cudnnSetTensorNdDescriptor(
+            src_desc.desc, to_cudnn_dtype(src.layout.dtype), 4, dimA, strideA));
+    cudnn_check(cudnnSetTensorNdDescriptor(
+            dst_desc.desc, to_cudnn_dtype(dst.layout.dtype), 4, dimA, strideA));
+
+    cudnn_check(cudnnSoftmaxForward(
+            cudnn_handle(this->handle()), CUDNN_SOFTMAX_ACCURATE, mode, &alpha,
+            src_desc.desc, src.raw_ptr(), &beta, dst_desc.desc, dst.raw_ptr()));
+}
+
+//================================Softmax Backward============================
+
+std::vector<int> SoftmaxBackwardImpl::init_mode(
+        _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const {
+    auto dims = src.layout.shape;
+    const int rank = src.layout.ndim;
+    const int axis = CanonicalAxis(param().axis, rank);
+    const int dim = dims[axis];
+    const int N = SizeToAxis(axis, dims);
+    const int D = SizeOutAxis(axis, dims, rank);
+
+    mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL;
+    return {N, dim, D, 1};
+}
+
+void SoftmaxBackwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+        _megdnn_workspace workspace) {
+    dt_float32 alpha = 1.0f, beta = 0.0f;
+    TensorDesc src_desc, diff_desc, grad_desc;
+    cudnnSoftmaxMode_t mode;
+    std::vector<int> tensor_dims = init_mode(src, mode);
+    const int dimA[] = {
+            sc(tensor_dims[0]), sc(tensor_dims[1]), sc(tensor_dims[2]),
+            sc(tensor_dims[3])};
+    const int strideA[] = {
+            sc(tensor_dims[1] * tensor_dims[2] * tensor_dims[3]),
+            sc(tensor_dims[2] * tensor_dims[3]), sc(tensor_dims[3]), 1};
+
+    cudnn_check(cudnnSetTensorNdDescriptor(
+            src_desc.desc, to_cudnn_dtype(src.layout.dtype), 4, dimA, strideA));
+    cudnn_check(cudnnSetTensorNdDescriptor(
+            diff_desc.desc, to_cudnn_dtype(diff.layout.dtype), 4, dimA, strideA));
+    cudnn_check(cudnnSetTensorNdDescriptor(
+            grad_desc.desc, to_cudnn_dtype(grad.layout.dtype), 4, dimA, strideA));
+
+    cudnn_check(cudnnSoftmaxBackward(
+            cudnn_handle(this->handle()), CUDNN_SOFTMAX_ACCURATE, mode, &alpha,
+            src_desc.desc, src.raw_ptr(), diff_desc.desc, diff.raw_ptr(), &beta,
+            grad_desc.desc, grad.raw_ptr()));
+}
+
+// vim: syntax=cpp.doxygen

@@ -0,0 +1,58 @@
+/**
+ * \file dnn/src/cuda/softmax/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+#include "src/common/algo_base.h"
+#include "src/common/metahelper.h"
+#include "src/cuda/cudnn_wrapper.h"
+#include "src/cuda/utils.h"
+
+namespace megdnn {
+namespace cuda {
+
+class SoftmaxForwardImpl final : public SoftmaxForward {
+public:
+    using SoftmaxForward::SoftmaxForward;
+
+    std::vector<int> init_mode(_megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const;
+
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace) override;
+
+    size_t get_workspace_in_bytes(
+            const TensorLayout& /* src */,
+            const TensorLayout& /* dst */) override {
+        return 0;
+    }
+};
+
+class SoftmaxBackwardImpl final : public SoftmaxBackward {
+public:
+    using SoftmaxBackward::SoftmaxBackward;
+
+    std::vector<int> init_mode(_megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const;
+
+    size_t get_workspace_in_bytes(
+            const TensorLayout& /* input */, const TensorLayout& /* diff */,
+            const TensorLayout& /* grad_x */) override {
+        return 0;
+    }
+
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+            _megdnn_workspace workspace) override;
+};
+
+}  // namespace cuda
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen

@@ -81,6 +81,7 @@
 #include "src/naive/separable_filter/opr_impl.h"
 #include "src/naive/sleep/opr_impl.h"
 #include "src/naive/sliding_window_transpose/opr_impl.h"
+#include "src/naive/softmax/opr_impl.h"
 #include "src/naive/split/opr_impl.h"
 #include "src/naive/svd/opr_impl.h"
 #include "src/naive/tensor_remap/opr_impl.h"

@@ -0,0 +1,116 @@
+/**
+ * \file dnn/src/naive/softmax/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "src/naive/softmax/opr_impl.h"
+
+#include <cstring>
+
+#include "megdnn/dtype.h"
+#include "megdnn/tensor_iter.h"
+#include "src/common/elemwise_helper.cuh"
+#include "src/common/opr_delegate.h"
+#include "src/common/reduce_helper.h"
+#include "src/common/utils.h"
+#include "src/naive/elemwise/opr_impl.h"
+#include "src/naive/handle.h"
+#include "src/naive/lowbit_utils.h"
+
+using namespace megdnn;
+
+namespace {
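+// Runs `opr` on `src`, carving both the result tensor and the operator's own
+// workspace out of the caller-provided workspace buffer.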
+template <typename T>
+TensorND op_exec(_megdnn_tensor_in src, megdnn::dt_byte* workspace_ptr, const T& opr) {
+    TensorLayout dst_layout;
+    opr->deduce_layout(src.layout, dst_layout);
+    TensorND dst{workspace_ptr, dst_layout};
+    workspace_ptr += dst_layout.span().dist_byte();
+
+    auto new_workspace = Workspace{
+            workspace_ptr, opr->get_workspace_in_bytes(src.layout, dst_layout)};
+    workspace_ptr += opr->get_workspace_in_bytes(src.layout, dst_layout);
+
+    opr->exec(src, dst, new_workspace);
+    return dst;
+}
+}  // namespace
+
+namespace megdnn {
+namespace naive {
+
+//===============================Softmax Forward============================
+
+void SoftmaxForwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    auto axis = param().axis;
+    if (axis < 0)
+        axis += src.layout.ndim;
+    check_exec(src.layout, dst.layout, workspace.size);
+    auto workspace_ptr = workspace.raw_ptr;
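+
+    // Numerically stable softmax composed from reduce/elemwise primitives:
+    // y = exp(x - max(x, axis)) / sum(exp(x - max(x, axis)), axis)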
+    auto reduce_opr = handle()->create_operator<ReduceForward>();
+    reduce_opr->param().axis = axis;
+    reduce_opr->param().mode = Reduce::Mode::MAX;
+    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
+    TensorND max_tensor = op_exec(src, workspace_ptr, reduce_opr);
+
+    auto elemwise_opr = handle()->create_operator<Elemwise>();
+    elemwise_opr->param().mode = Elemwise::Mode::SUB;
+    elemwise_opr->exec({src, max_tensor}, dst);
+
+    elemwise_opr->param().mode = Elemwise::Mode::EXP;
+    TensorLayout exp_layout;
+    elemwise_opr->deduce_layout({src.layout}, exp_layout);
+    TensorND exp_tensor{workspace_ptr, exp_layout};
+    workspace_ptr += exp_layout.span().dist_byte();
+    elemwise_opr->exec({dst}, exp_tensor);
+
+    reduce_opr->param().mode = Reduce::Mode::SUM;
+    TensorND down_tensor = op_exec(exp_tensor, workspace_ptr, reduce_opr);
+
+    elemwise_opr->param().mode = Elemwise::Mode::TRUE_DIV;
+    elemwise_opr->exec({exp_tensor, down_tensor}, dst);
+}
+
+//=============================Softmax Backward============================
+
+void SoftmaxBackwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+        _megdnn_workspace workspace) {
+    auto axis = param().axis;
+    if (axis < 0)
+        axis += src.layout.ndim;
+    check_exec(src.layout, diff.layout, grad.layout, workspace.size);
+    auto workspace_ptr = workspace.raw_ptr;
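+
+    // Softmax gradient: grad = y * diff - y * sum(y * diff, axis), where src
+    // holds the forward output y (see the gradient rule wired up in
+    // src/opr/impl/dnn/softmax.cpp).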
+    TensorLayout mulres = src.layout;
+    mulres.dtype = src.layout.dtype;
+    mulres.format = src.layout.format;
+    mulres.init_contiguous_stride();
+
+    TensorND mul_tensor{workspace_ptr, mulres};
+    workspace_ptr += mulres.span().dist_byte();
+    TensorND mul_tensor2{workspace_ptr, mulres};
+    workspace_ptr += mulres.span().dist_byte();
+
+    auto elemwise_opr = handle()->create_operator<Elemwise>();
+    elemwise_opr->param().mode = Elemwise::Mode::MUL;
+    elemwise_opr->exec({src, diff}, mul_tensor);
+
+    auto reduce_opr = handle()->create_operator<ReduceForward>();
+    reduce_opr->param().axis = axis;
+    reduce_opr->param().mode = Reduce::Mode::SUM;
+    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
+    TensorND sum_tensor = op_exec(mul_tensor, workspace_ptr, reduce_opr);
+
+    elemwise_opr->exec({sum_tensor, src}, mul_tensor2);
+
+    elemwise_opr->param().mode = Elemwise::Mode::SUB;
+    elemwise_opr->exec({mul_tensor, mul_tensor2}, grad);
+}
+
+}  // namespace naive
+}  // namespace megdnn

@@ -0,0 +1,45 @@
+/**
+ * \file dnn/src/naive/softmax/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace naive {
+
+class SoftmaxForwardImpl final : public SoftmaxForward {
+public:
+    using SoftmaxForward::SoftmaxForward;
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace) override;
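+    // two src-sized scratch buffers are a conservative bound for the
+    // intermediate exp tensor plus the reduce results (see opr_impl.cpp)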
+    size_t get_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout&) override {
+        return src.span().dist_byte() * 2;
+    }
+};
+
+class SoftmaxBackwardImpl final : public SoftmaxBackward {
+public:
+    using SoftmaxBackward::SoftmaxBackward;
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout&,
+            const TensorLayout&) override {
+        return src.span().dist_byte() * 3;
+    }
+};
+
+}  // namespace naive
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen

@@ -0,0 +1,41 @@
+/**
+ * \file dnn/test/common/softmax.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include <cstddef>
+#include "megdnn/basic_types.h"
+#include "megdnn/opr_param_defs.h"
+
+namespace megdnn {
+namespace test {
+namespace softmax {
+
+struct TestArg {
+    param::Softmax param;
+    TensorShape ishape;
+    TestArg(param::Softmax param, TensorShape ishape) : param(param), ishape(ishape) {}
+};
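+
+// one test case per axis of a fixed 5-d shape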
+inline std::vector<TestArg> get_args() {
+    std::vector<TestArg> args;
+    using Param = param::Softmax;
+
+    for (int32_t axis = 0; axis < 5; axis++) {
+        args.emplace_back(Param{axis}, TensorShape{2, 23, 32, 30, 17});
+    }
+    return args;
+}
+
+}  // namespace softmax
+}  // namespace test
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen

@@ -0,0 +1,71 @@
+/**
+ * \file dnn/test/cuda/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "test/cuda/fixture.h"
+
+#include "megdnn/tensor_iter.h"
+#include "test/common/checker.h"
+#include "test/common/softmax.h"
+
+#include "src/common/utils.h"
+#include "test/cuda/utils.h"
+
+// to check cudnn version
+#include <cudnn.h>
+#include "test/cuda/benchmark.h"
+
+namespace megdnn {
+namespace test {
+
+TEST_F(CUDA, SOFTMAX_FORWARD) {
+    auto args = softmax::get_args();
+    std::vector<DType> dtypes{dtype::Float16(), dtype::Float32()};
+
+    for (auto dtype : dtypes)
+        for (auto&& arg : args) {
+            auto param = arg.param;
+            auto src = arg.ishape;
+            Checker<Softmax> checker(handle_cuda());
+            // loosen tolerance for the half-precision dtype
+            if (dtype == dtype::Float16()) {
+                checker.set_epsilon(2e-2);
+            } else {
+                checker.set_epsilon(1e-2);
+            }
+            checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
+                    TensorShapeArray{src, {}});
+        }
+}
+
+TEST_F(CUDA, SOFTMAX_BACKWARD) {
+    auto args = softmax::get_args();
+    for (auto&& arg : args) {
+        Checker<SoftmaxBackward> checker(handle_cuda());
+        TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
+        TensorLayout olayout;
+
+        {
+            auto opr = handle_cuda()->create_operator<SoftmaxForward>();
+            opr->param() = arg.param;
+            opr->deduce_layout(ilayout, olayout);
+        }
+        auto set_dtype = [&checker](DType dtype) {
+            checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
+        };
+
+        set_dtype(dtype::Float32());
+        checker.set_epsilon(1e-3).set_param(arg.param).exec(
+                TensorShapeArray{ilayout, olayout, ilayout});
+    }
+}
+
+}  // namespace test
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen

@@ -0,0 +1,56 @@
+/**
+ * \file dnn/test/naive/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "test/naive/fixture.h"
+
+#include "megdnn/oprs/nn.h"
+#include "test/common/checker.h"
+
+using namespace megdnn;
+using namespace test;
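+
+// softmax along axis 0 pairs each x with x + 8, so every output is either
+// exp(0) / (exp(0) + exp(8)) ~= 0.0003 or exp(8) / (exp(0) + exp(8)) ~= 0.9997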
+TEST_F(NAIVE, SOFTMAX_FORWARD) {
+    Checker<Softmax> checker(handle(), /* check_dispatch */ false);
+
+    Softmax::Param param{0};
+
+    TensorND input = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
+
+    TensorND output = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
+             0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
+
+    checker.set_param(param).exect(Testcase{input, {}}, Testcase{{}, output});
+}
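+
+// diff is all ones, so grad = y * (1 - sum(y, axis)) = 0 everywhere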
+TEST_F(NAIVE, SOFTMAX_BACKWARD) {
+    Checker<SoftmaxBackward> checker(handle(), /* check_dispatch */ false);
+
+    Softmax::Param param{0};
+
+    TensorND input = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
+             0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
+
+    TensorND diff = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
+
+    TensorND output = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
+
+    checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
+}

@@ -1061,10 +1061,15 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     """
     if axis is None:
         axis = _get_softmax_axis(len(inp.shape))
-    offset = inp.max(axis=axis, keepdims=True).detach()
-    cached = exp(inp - offset)
-    down = sum(cached, axis=axis, keepdims=True)
-    return cached / down
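+    # the builtin Softmax op takes a single int axis; multi-axis softmax keeps
+    # the original exp/sum composition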
+    if isinstance(axis, list):
+        offset = inp.max(axis=axis, keepdims=True).detach()
+        cached = exp(inp - offset)
+        down = sum(cached, axis=axis, keepdims=True)
+        return cached / down
+    else:
+        op = builtin.Softmax(axis=axis)
+        (output,) = apply(op, inp)
+        return output
 
 
 def layer_norm(

@@ -0,0 +1,52 @@
+/**
+ * \file imperative/src/impl/ops/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "megbrain/opr/dnn/softmax.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+#include "../dnn_op_helper.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+namespace {
+namespace softmax {
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& softmax = static_cast<const Softmax&>(def);
+    OperatorNodeConfig config{softmax.make_name()};
+    return opr::Softmax::make(inputs[0], softmax.param(), config);
+}
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::Softmax>();
+    return Softmax::make(node->param());
+}
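+
+// softmax preserves shape and dtype, so the output descriptor is simply the
+// input descriptor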
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef&, const SmallVector<LogicalTensorDesc>& inputs) {
+    SmallVector<LogicalTensorDesc> out_shapes(1);
+    auto&& i0 = inputs[0];
+    out_shapes[0] = {i0.layout, i0.comp_node};
+    return {out_shapes, true};
+}
+
+OP_TRAIT_REG(Softmax, Softmax, opr::Softmax)
+        .make_from_op_node(make_from_op_node)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .fallback();
+
+}  // namespace softmax
+}  // namespace
+}  // namespace imperative
+}  // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -354,6 +354,7 @@ def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
 def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>;
 def TQT: MgbHashableOp<"TQT", [TQTParam]>;
 def LSQ: MgbHashableOp<"LSQ", [LSQParam]>;
+def Softmax: MgbHashableOp<"Softmax", [SoftmaxParam]>;
 def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
     let extraArguments = (ins
         MgbDTypeAttr:$dtype

@@ -327,4 +327,7 @@ decl_opr('TQT',
 decl_opr('LSQ',
          inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor'),Doc('grad_scale','grad scale tensor')],
          params='LSQ')
+decl_opr('Softmax',
+         inputs=[Doc('src','input tensor')],
+         params='Softmax')
 
 # vim: ft=python

@@ -25,6 +25,7 @@
 #include "megbrain/opr/dnn/roi_align.h"
 #include "megbrain/opr/dnn/roi_pooling.h"
 #include "megbrain/opr/dnn/sliding_window_transpose.h"
+#include "megbrain/opr/dnn/softmax.h"
 #include "megbrain/opr/dnn/tqt.h"
 #include "megbrain/serialization/sereg.h"
 #include "megdnn/opr_param_defs.h"
@@ -325,6 +326,19 @@ struct OprMaker<opr::LSTMBackward, 9> {
 };
 
 template <>
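+// SoftmaxBackward takes two inputs (the forward output and the incoming
+// gradient), hence the arity 2 in this maker.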
+struct OprMaker<opr::SoftmaxBackward, 2> {
+    using Param = opr::SoftmaxBackward::Param;
+    static cg::OperatorNodeBase* make(
+            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
+            const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+        return opr::SoftmaxBackward::make(i[0], i[1], param, config)
+                .node()
+                ->owner_opr();
+    }
+};
+
+template <>
 struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
         : public GeneralOprLoadDumpImpl<
                   opr::AdaptivePoolingBackward,
@@ -720,6 +734,8 @@
 MGB_SEREG_OPR(RNNForward, 3);
 MGB_SEREG_OPR(RNNBackward, 7);
 MGB_SEREG_OPR(LSTMForward, 4);
 MGB_SEREG_OPR(LSTMBackward, 9);
+MGB_SEREG_OPR(Softmax, 1);
+MGB_SEREG_OPR(SoftmaxBackward, 2);
 
 }  // namespace opr
 }  // namespace mgb

@@ -0,0 +1,124 @@
+/**
+ * \file src/opr/impl/dnn/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "megbrain/opr/dnn/softmax.h"
+
+#include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/internal/out_shape_by_sym_var.h"
+#include "megbrain/opr/utility.h"
+
+#include "../internal/megdnn_opr_wrapper.inl"
+
+using namespace mgb;
+using namespace opr;
+
+/* ==================== SoftmaxForward ==================== */
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(SoftmaxForward);
+
+SoftmaxForward::SoftmaxForward(
+        VarNode* inp, const Param& param, const OperatorNodeConfig& config)
+        : Super{inp->owner_graph(), config, "softmax", {inp}} {
+    init_megdnn_opr(*this, param);
+    add_input({inp});
+    output(0)->dtype(inp->dtype());
+}
+
+SymbolVar SoftmaxForward::make(
+        SymbolVar inp, const Param& param, const OperatorNodeConfig& config) {
+    auto out = inp.node()
+                       ->owner_graph()
+                       ->insert_opr(std::make_unique<SoftmaxForward>(
+                               inp.node(), param, config))
+                       ->output();
+    return out[0];
+}
+
+void SoftmaxForward::get_output_var_shape(
+        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
+    out_shape[0] = inp_shape[0];
+}
+
+size_t SoftmaxForward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    return megdnn_opr()->get_workspace_in_bytes(
+            {input_shapes[0], input(0)->dtype(), input(0)->format()},
+            {output_shapes[0], output(0)->dtype(), output(0)->format()});
+}
+
+void SoftmaxForward::scn_do_execute() {
+    megdnn_opr()->exec(
+            input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output().back()));
+}
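+
+// Gradient rule: feed the *output* of the forward opr (not its input) into
+// SoftmaxBackward, matching dx = y * (dy - sum(y * dy, axis)).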
+#if MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(SoftmaxForward) {
+    SymbolVar grad = SoftmaxBackward::make(opr.output(0), out_grad[0], opr.param());
+    return grad.node();
+}
+#endif
+
+/* ==================== SoftmaxBackward ==================== */
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(SoftmaxBackward);
+
+SoftmaxBackward::SoftmaxBackward(
+        VarNode* src, VarNode* diff, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super({src->owner_graph(), config, "softmax_backward", {src, diff}}, 0,
+                true) {
+    init_megdnn_opr(*this, param);
+    add_input({src, diff});
+}
+
+SymbolVar SoftmaxBackward::make(
+        SymbolVar src, SymbolVar diff, const Param& param,
+        const OperatorNodeConfig& config) {
+    auto out = src.node()
+                       ->owner_graph()
+                       ->insert_opr(std::make_unique<SoftmaxBackward>(
+                               src.node(), diff.node(), param, config))
+                       ->output();
+    return out[0];
+}
+
+void SoftmaxBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
+    this->init_output_static_infer_desc_workspace(false);
+}
+
+void SoftmaxBackward::init_output_dtype() {
+    output(0)->dtype(input(0)->dtype());
+}
+
+size_t SoftmaxBackward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    return megdnn_opr()->get_workspace_in_bytes(
+            {input_shapes[0], input(0)->dtype(), input(0)->format()},
+            {input_shapes[1], input(1)->dtype(), input(1)->format()},
+            {output_shapes[0], output(0)->dtype(), output(0)->format()});
+}
+
+void SoftmaxBackward::scn_do_execute() {
+    megdnn_opr()->exec(
+            input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+            output(0)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output().back()));
+}
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -0,0 +1,64 @@
+/**
+ * \file src/opr/include/megbrain/opr/dnn/softmax.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs/nn.h"
+
+namespace mgb {
+namespace opr {
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        SoftmaxForward, intl::MegDNNOprWrapperFwd<megdnn::SoftmaxForward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC SoftmaxForward(
+            VarNode* src, const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            SymbolVar src, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+
+private:
+    void get_output_var_shape(
+            const TensorShapeArray& inp_shape,
+            TensorShapeArray& out_shape) const override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+using Softmax = SoftmaxForward;
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        SoftmaxBackward, intl::MegDNNOprWrapperBwd<megdnn::SoftmaxBackward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC SoftmaxBackward(
+            VarNode* x, VarNode* y_grad, const Param& param,
+            const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            SymbolVar x, SymbolVar y_grad, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+
+private:
+    void init_output_static_infer_desc() override;
+    void init_output_dtype() override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+
+}  // namespace opr
+}  // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -0,0 +1,65 @@
+/**
+ * \file src/opr/test/dnn/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "megbrain/opr/dnn/softmax.h"
+#include "megbrain/comp_node_env.h"
+#include "megbrain/test/autocheck.h"
+
+using namespace std;
+using namespace mgb;
+
+namespace {
+using Param = opr::SoftmaxForward::Param;
+
+void run(int32_t axis) {
+    using Checker = AutoOprChecker<1, 1>;
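+    // AutoOprChecker builds the graph via make_graph, compares it against the
+    // raw megdnn opr in fwd, and checks the gradient numerically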
+    Param param{axis};
+
+    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
+        auto o0 = opr::SoftmaxForward::make(inputs[0], param);
+        return {o0};
+    };
+
+    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
+        auto opr =
+                MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
+                        ->create_operator<megdnn::SoftmaxForward>();
+        opr->param() = param;
+        dest[0].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize(inp[0]->shape());
+        size_t wk_size =
+                opr->get_workspace_in_bytes(inp[0]->layout(), dest[0].layout());
+        std::unique_ptr<dt_byte[]> wk_store{new dt_byte[wk_size]};
+        opr->exec(inp[0]->as_megdnn(), dest[0].as_megdnn(), {wk_store.get(), wk_size});
+    };
+
+    auto gen = [&](HostTensorND& src) {
+        HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(10.f);
+        src = *src_gen(src.shape(), src.comp_node());
+    };
+
+    Checker::RunOptions opt;
+    opt.numdiff_max_err = 1e-4;
+    Checker checker{make_graph, fwd};
+    checker.set_input_generator(0, gen);
+    checker.run({TensorShape{1, 2, 3, 4}}, opt)
+            .run({TensorShape{2, 3, 8, 8}}, opt)
+            .run({TensorShape{1, 3, 4, 4}}, opt);
+}
+}  // anonymous namespace
+
+TEST(TestOprDNN, SoftmaxForward) {
+    REQUIRE_GPU(1);
+    run(1);
+}

@@ -121,6 +121,7 @@ union OperatorParam {
     param.RNNCell = 87,
     param.RNN = 88,
     param.LSTM = 89,
+    param.Softmax = 90,
 }
 
 table Operator {