@@ -1442,6 +1442,39 @@ protected: | |||||
void backward_check_exec(const TensorLayout& src, const TensorLayout& dst); | void backward_check_exec(const TensorLayout& src, const TensorLayout& dst); | ||||
}; | }; | ||||
class LAMBUpdate : public OperatorBase { | |||||
DEF_OPR_PARAM(LAMBUpdate); | |||||
// input=(m_t-1,v_t-1,lamb_param,grad) , output = (m_t,v_t,new_param) | |||||
DEF_OPR_IMPL(LAMBUpdate, OperatorBase, 4, 3); | |||||
public: | |||||
virtual void exec( | |||||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||||
_megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||||
_megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||||
const TensorLayout& m_t, const TensorLayout& v_t, | |||||
const TensorLayout& new_param) = 0; | |||||
void deduce_layout( | |||||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||||
const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t, | |||||
TensorLayout& v_t, TensorLayout& new_param); | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||||
const TensorLayout& m_t, const TensorLayout& v_t, | |||||
const TensorLayout& new_param, size_t workspace_in_bytes); | |||||
}; | |||||
using LAMB = LAMBUpdate; | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#include "megdnn/internal/opr_header_epilogue.h" | #include "megdnn/internal/opr_header_epilogue.h" | ||||
@@ -36,13 +36,13 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
add_enum(Doc('Format', 'convolution data/filter/output format; see ' | add_enum(Doc('Format', 'convolution data/filter/output format; see ' | ||||
':class:`RelayoutFormat` for more details'), | ':class:`RelayoutFormat` for more details'), | ||||
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | 'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | ||||
'NCHW44 = 7','NCHW44_DOT = 8', | |||||
'NCHW44 = 7','NCHW44_DOT = 8', | |||||
Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights tranformed by winograd'), | Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights tranformed by winograd'), | ||||
Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights tranformed by winograd'), | Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights tranformed by winograd'), | ||||
Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights tranformed by winograd'), | Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights tranformed by winograd'), | ||||
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||||
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||||
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||||
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||||
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||||
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||||
Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, ' | Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, ' | ||||
'output tensor is nchw layout'), | 'output tensor is nchw layout'), | ||||
Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | ||||
@@ -96,9 +96,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
add_enum(Doc('Format', 'convolution data/filter/output format; see ' | add_enum(Doc('Format', 'convolution data/filter/output format; see ' | ||||
':class:`RelayoutFormat` for more details'), | ':class:`RelayoutFormat` for more details'), | ||||
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | 'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | ||||
'NCHW44 = 7','NCHW44_DOT = 8', | |||||
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||||
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||||
'NCHW44 = 7','NCHW44_DOT = 8', | |||||
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||||
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||||
Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | ||||
Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, ' | Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, ' | ||||
'output tensor is nchw layout'), | 'output tensor is nchw layout'), | ||||
@@ -107,9 +107,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' | Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' | ||||
'output tensor is nchw4 layout, padding c=4'), | 'output tensor is nchw4 layout, padding c=4'), | ||||
Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | ||||
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'), | |||||
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'), | |||||
Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementation to utilizing TensorCore ' | Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementation to utilizing TensorCore ' | ||||
'instructions for 4-bit integers on Nvidia platforms'), | |||||
'instructions for 4-bit integers on Nvidia platforms'), | |||||
Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')). | Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')). | ||||
add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode') | add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode') | ||||
) | ) | ||||
@@ -1038,10 +1038,10 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||||
'NCHW_NCHW4 = 24', | 'NCHW_NCHW4 = 24', | ||||
'NCHW4_NCHW = 25', | 'NCHW4_NCHW = 25', | ||||
'NCHW_NCHW4_WEIGHT = 26', | 'NCHW_NCHW4_WEIGHT = 26', | ||||
'NCHW_NCHW64 = 27', | |||||
'NCHW64_NCHW = 28', | |||||
'NCHW_NHWC = 29', | |||||
'NHWC_NCHW = 30', | |||||
'NCHW_NCHW64 = 27', | |||||
'NCHW64_NCHW = 28', | |||||
'NCHW_NHWC = 29', | |||||
'NHWC_NCHW = 30', | |||||
'NHWCD4I_NHWC = 31', | 'NHWCD4I_NHWC = 31', | ||||
) | ) | ||||
) | ) | ||||
@@ -1264,3 +1264,14 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||||
add_fields('float32', Doc('dropout', 'If introduce a Dropout layer on the outputs of each LSTM layer'), '0.f'). | add_fields('float32', Doc('dropout', 'If introduce a Dropout layer on the outputs of each LSTM layer'), '0.f'). | ||||
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') | add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') | ||||
) | ) | ||||
# Parameters of the LAMBUpdate operator (one optimizer step of LAMB).
# Field names, types, order and defaults are part of the operator interface;
# only the human-readable descriptions are corrected here.
(pdef('LAMBUpdate').
 add_fields('float32', Doc('beta_1', 'beta_1 parameter of lamb'), '1.f').
 add_fields('float32', Doc('beta_2', 'beta_2 parameter of lamb'), '1.f').
 add_fields('float32', Doc('step', 'training step'), '1.f').
 add_fields('float32', Doc('lr', 'learning rate'), '1.f').
 add_fields('float32', Doc('weight_decay', 'weight decay to adjust learning rate'), '1.f').
 add_fields('float32', Doc('eps', 'eps added to the denominator for numerical stability'), '1.f').
 add_fields('bool', Doc('bias_correction', 'whether to apply bias correction'), 'true').
 add_fields('bool', Doc('always_adapt', 'apply adaptive lr even when weight decay is 0.0'), 'false')
 )
@@ -209,6 +209,7 @@ private: | |||||
cb(RNN) \ | cb(RNN) \ | ||||
cb(RNNBackward) \ | cb(RNNBackward) \ | ||||
cb(LSTM) \ | cb(LSTM) \ | ||||
cb(LAMBUpdate) \ | |||||
cb(LSTMBackward) \ | cb(LSTMBackward) \ | ||||
cb(SoftmaxForward) \ | cb(SoftmaxForward) \ | ||||
cb(SoftmaxBackward) | cb(SoftmaxBackward) | ||||
@@ -0,0 +1,25 @@ | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
void LAMBUpdate::deduce_layout( | |||||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||||
const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t, | |||||
TensorLayout& v_t, TensorLayout& new_param) { | |||||
m_t = TensorLayout(m_t_1); | |||||
v_t = TensorLayout(v_t_1); | |||||
new_param = TensorLayout(lamb_param); | |||||
MEGDNN_MARK_USED_VAR(grad); | |||||
} | |||||
void LAMBUpdate::check_exec( | |||||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||||
const TensorLayout& m_t, const TensorLayout& v_t, const TensorLayout& new_param, | |||||
size_t workspace_in_bytes) { | |||||
auto required_workspace_in_bytes = | |||||
get_workspace_in_bytes(m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
} | |||||
} // namespace megdnn |
@@ -127,6 +127,7 @@ DEF(LSQBackward, 7, true, false); | |||||
DEF(Fill, 1, true, false); | DEF(Fill, 1, true, false); | ||||
DEF(LayerNormForward, 6, true, true); | DEF(LayerNormForward, 6, true, true); | ||||
DEF(LayerNormBackward, 8, true, true); | DEF(LayerNormBackward, 8, true, true); | ||||
DEF(LAMBUpdate, 7, true, true); | |||||
DEF(DropoutForward, 3, true, true); | DEF(DropoutForward, 3, true, true); | ||||
DEF(DropoutBackward, 3, true, true); | DEF(DropoutBackward, 3, true, true); | ||||
DEF(RNNCellForward, 7, true, true); | DEF(RNNCellForward, 7, true, true); | ||||
@@ -35,6 +35,7 @@ | |||||
#include "src/cuda/images2neibs/opr_impl.h" | #include "src/cuda/images2neibs/opr_impl.h" | ||||
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | #include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | ||||
#include "src/cuda/indexing_one_hot/opr_impl.h" | #include "src/cuda/indexing_one_hot/opr_impl.h" | ||||
#include "src/cuda/lamb/opr_impl.h" | |||||
#include "src/cuda/layer_norm/opr_impl.h" | #include "src/cuda/layer_norm/opr_impl.h" | ||||
#include "src/cuda/linspace/opr_impl.h" | #include "src/cuda/linspace/opr_impl.h" | ||||
#include "src/cuda/local/opr_impl.h" | #include "src/cuda/local/opr_impl.h" | ||||
@@ -210,6 +211,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LAMBUpdate); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward); | ||||
@@ -0,0 +1,102 @@ | |||||
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/pair.h>
#include <thrust/transform_reduce.h>
#include <thrust/tuple.h>
#include <cfloat>
#include "megdnn/arch.h"
#include "megdnn/dtype.h"
#include "src/cuda/cuda_shfl_compat.cuh"
#include "src/cuda/lamb/lamb_cuda.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn { | |||||
namespace cuda { | |||||
namespace lamb { | |||||
template <typename T> | |||||
struct square { | |||||
__host__ __device__ T operator()(const T& x) const { return x * x; } | |||||
}; | |||||
template <typename T, typename T_ACC> | |||||
__global__ void update_kernal_1( | |||||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||||
size_t total_nr_elem) { | |||||
size_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||||
T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1, | |||||
bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||||
if (idx < total_nr_elem) { | |||||
m_t[idx] = beta_1 * m_t_1[idx] + (1 - beta_1) * static_cast<T_ACC>(grad[idx]); | |||||
v_t[idx] = beta_2 * v_t_1[idx] + | |||||
(1 - beta_2) * std::pow(static_cast<T_ACC>(grad[idx]), 2); | |||||
rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps); | |||||
if (weight_decay != 0) { | |||||
rt[idx] += lamb_param[idx] * weight_decay; | |||||
} | |||||
} | |||||
} | |||||
template <typename T, typename T_ACC> | |||||
__global__ void update_kernal_2( | |||||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||||
size_t total_nr_elem, T_ACC trust_ratio) { | |||||
size_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||||
T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1, | |||||
bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||||
if (idx < total_nr_elem) { | |||||
rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps); | |||||
if (weight_decay != 0) { | |||||
rt[idx] += lamb_param[idx] * weight_decay; | |||||
} | |||||
new_param[idx] = lamb_param[idx] - lr * trust_ratio * rt[idx]; | |||||
} | |||||
} | |||||
template <typename T, typename T_ACC> | |||||
void update( | |||||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||||
size_t total_nr_elem, cudaStream_t stream) { | |||||
size_t NR_BLOCKS = DIVUP(total_nr_elem, NR_THREADS); | |||||
update_kernal_1<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||||
m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2, | |||||
step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem); | |||||
after_kernel_launch(); | |||||
thrust::device_ptr<T_ACC> lamb_param_ptr(lamb_param); | |||||
thrust::device_ptr<T_ACC> rt_ptr(rt); | |||||
square<T_ACC> unary_op; | |||||
thrust::plus<T_ACC> binary_op; | |||||
T_ACC p_norm = std::sqrt(thrust::transform_reduce( | |||||
lamb_param_ptr, lamb_param_ptr + total_nr_elem, unary_op, 0.f, binary_op)); | |||||
T_ACC d_norm = std::sqrt(thrust::transform_reduce( | |||||
rt_ptr, rt_ptr + total_nr_elem, unary_op, 0.f, binary_op)); | |||||
T_ACC trust_ratio = 1; | |||||
if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) { | |||||
trust_ratio = p_norm / d_norm; | |||||
} | |||||
update_kernal_2<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||||
m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2, | |||||
step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem, | |||||
trust_ratio); | |||||
after_kernel_launch(); | |||||
} | |||||
// Explicit instantiations of update<T, T_ACC>: the gradient may be fp32, fp16
// or bf16, while states and parameters are always accumulated in fp32.
#define INST(T, T_ACC)                                                          \
    template void update<T, T_ACC>(                                             \
            T_ACC*, T_ACC*, T_ACC*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC*, float,  \
            float, float, float, float, float, bool, bool, size_t, cudaStream_t);

INST(dt_float32, dt_float32)
INST(dt_float16, dt_float32)
INST(dt_bfloat16, dt_float32)
#undef INST
} // namespace lamb | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,17 @@ | |||||
#pragma once | |||||
#include <cuda_runtime_api.h> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace lamb { | |||||
template <typename T, typename T_ACC> | |||||
void update( | |||||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||||
size_t total_nr_elem, cudaStream_t stream); | |||||
} // namespace lamb | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,45 @@ | |||||
#include "src/cuda/lamb/opr_impl.h" | |||||
#include "./lamb_cuda.cuh" | |||||
#include "src/cuda/utils.h" | |||||
#include <cmath> | |||||
#include <functional> | |||||
#include <numeric> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
/*!
 * \brief CUDA implementation of the LAMB update.
 *
 * Validates the arguments, then dispatches on the dtype of \p grad. All other
 * tensors are accessed as float, and \p workspace is used as the float scratch
 * buffer that holds the per-element update direction.
 *
 * Throws via megdnn_throw("bad dtype") when \p grad has an unsupported dtype.
 */
void LAMBUpdateImpl::exec(
        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
        _megdnn_tensor_out new_param, _megdnn_workspace workspace) {
    // Same validation as the naive backend: the kernels write one float per
    // element into the workspace, so an undersized buffer must be rejected.
    check_exec(
            m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout,
            v_t.layout, new_param.layout, workspace.size);

    auto p = param();
    float beta_1 = p.beta_1;
    float beta_2 = p.beta_2;
    float step = p.step;
    float lr = p.lr;
    float weight_decay = p.weight_decay;
    float eps = p.eps;
    bool bias_correction = p.bias_correction;
    bool always_adapt = p.always_adapt;
    size_t total_elem = lamb_param.layout.total_nr_elems();
    auto stream = cuda_stream(handle());
    using namespace ::megdnn::cuda::lamb;

#define cb(DType)                                                                   \
    if (grad.layout.dtype == DType()) {                                             \
        using T = typename DTypeTrait<DType>::ctype;                                \
        using T_ACC = float;                                                        \
        update<T, T_ACC>(                                                           \
                m_t_1.ptr<T_ACC>(), v_t_1.ptr<T_ACC>(), lamb_param.ptr<T_ACC>(),    \
                grad.ptr<T>(), m_t.ptr<T_ACC>(), v_t.ptr<T_ACC>(),                  \
                new_param.ptr<T_ACC>(), workspace.ptr<T_ACC>(), beta_1, beta_2,     \
                step, lr, weight_decay, eps, bias_correction, always_adapt,         \
                total_elem, stream);                                                \
        return;                                                                     \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
    megdnn_throw("bad dtype");
}
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,25 @@ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/cuda/cudnn_wrapper.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class LAMBUpdateImpl final : public LAMBUpdate { | |||||
public: | |||||
using LAMBUpdate::LAMBUpdate; | |||||
void exec( | |||||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||||
_megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||||
_megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||||
const TensorLayout& m_t, const TensorLayout& v_t, | |||||
const TensorLayout& new_param) override { | |||||
return m_t.access_bytes(); | |||||
}; | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -37,6 +37,7 @@ | |||||
#include "src/naive/images2neibs/opr_impl.h" | #include "src/naive/images2neibs/opr_impl.h" | ||||
#include "src/naive/indexing_multi_axis_vec/opr_impl.h" | #include "src/naive/indexing_multi_axis_vec/opr_impl.h" | ||||
#include "src/naive/indexing_one_hot/opr_impl.h" | #include "src/naive/indexing_one_hot/opr_impl.h" | ||||
#include "src/naive/lamb/opr_impl.h" | |||||
#include "src/naive/layer_norm/opr_impl.h" | #include "src/naive/layer_norm/opr_impl.h" | ||||
#include "src/naive/linspace/opr_impl.h" | #include "src/naive/linspace/opr_impl.h" | ||||
#include "src/naive/local/opr_impl.h" | #include "src/naive/local/opr_impl.h" | ||||
@@ -0,0 +1,89 @@ | |||||
#include "src/naive/lamb/opr_impl.h" | |||||
#include <cmath> | |||||
#include <functional> | |||||
#include <numeric> | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
using namespace megdnn; | |||||
using namespace naive; | |||||
namespace { | |||||
using Param = megdnn::LAMBUpdate::Param; | |||||
template <typename T, typename T_ACC = float> | |||||
void update( | |||||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param, | |||||
_megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||||
_megdnn_tensor_out new_param, const Param& param) { | |||||
float beta_1 = param.beta_1; | |||||
float beta_2 = param.beta_2; | |||||
float step = param.step; | |||||
float lr = param.lr; | |||||
float weight_decay = param.weight_decay; | |||||
float eps = param.eps; | |||||
bool bias_correction = param.bias_correction; | |||||
bool always_adapt = param.always_adapt; | |||||
size_t total_elem = lamb_param.layout.total_nr_elems(); | |||||
T_ACC mt, vt, bc_1, bc_2, rt, d_norm = 0; | |||||
bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1; | |||||
bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||||
for (size_t i = 0; i < total_elem; i++) { | |||||
mt = m_t.ptr<T_ACC>()[i] = beta_1 * m_t_1.ptr<T_ACC>()[i] + | |||||
(1 - beta_1) * static_cast<T_ACC>(grad.ptr<T>()[i]); | |||||
vt = v_t.ptr<T_ACC>()[i] = | |||||
beta_2 * v_t_1.ptr<T_ACC>()[i] + | |||||
(1 - beta_2) * std::pow(static_cast<T_ACC>(grad.ptr<T>()[i]), 2); | |||||
rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps); | |||||
if (weight_decay != 0) { | |||||
rt += lamb_param.ptr<T_ACC>()[i] * weight_decay; | |||||
} | |||||
d_norm += rt * rt; | |||||
} | |||||
d_norm = sqrt(d_norm); | |||||
auto get_norm = [=](_megdnn_tensor_in norm) -> T_ACC { | |||||
return sqrt(std::accumulate( | |||||
norm.ptr<T_ACC>(), norm.ptr<T_ACC>() + total_elem, 0, | |||||
[](T_ACC t1, T_ACC t2) -> T_ACC { return t1 + t2 * t2; })); | |||||
}; | |||||
T_ACC p_norm = get_norm(lamb_param), trust_ratio = 1; | |||||
if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) { | |||||
trust_ratio = p_norm / d_norm; | |||||
} | |||||
for (size_t i = 0; i < total_elem; i++) { | |||||
mt = m_t.ptr<T_ACC>()[i]; | |||||
vt = v_t.ptr<T_ACC>()[i]; | |||||
rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps); | |||||
if (weight_decay != 0) { | |||||
rt += lamb_param.ptr<T_ACC>()[i] * weight_decay; | |||||
} | |||||
new_param.ptr<T_ACC>()[i] = lamb_param.ptr<T_ACC>()[i] - lr * trust_ratio * rt; | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace naive { | |||||
/*!
 * \brief Naive-backend entry point of the LAMB update.
 *
 * Checks the arguments, then dispatches the reference kernel for the dtype of
 * \p grad. Throws via megdnn_throw("bad dtype") for unsupported dtypes.
 */
void LAMBUpdateImpl::exec(
        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
        _megdnn_tensor_out new_param, _megdnn_workspace workspace) {
    check_exec(
            m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout,
            v_t.layout, new_param.layout, workspace.size);

    // One branch per floating-point computing dtype; the first match dispatches
    // the kernel and returns.
#define DISPATCH_GRAD_DTYPE(DType)                                                 \
    if (grad.layout.dtype == DType()) {                                            \
        MEGDNN_DISPATCH_CPU_KERN_OPR(update<typename DTypeTrait<DType>::ctype>(    \
                m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, param()));    \
        return;                                                                    \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(DISPATCH_GRAD_DTYPE)
#undef DISPATCH_GRAD_DTYPE

    megdnn_throw("bad dtype");
}
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,34 @@ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class LAMBUpdateImpl final : public LAMBUpdate { | |||||
public: | |||||
using LAMBUpdate::LAMBUpdate; | |||||
void exec( | |||||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||||
_megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||||
_megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||||
const TensorLayout& m_t, const TensorLayout& v_t, | |||||
const TensorLayout& new_param) override { | |||||
MEGDNN_MARK_USED_VAR(m_t_1); | |||||
MEGDNN_MARK_USED_VAR(v_t_1); | |||||
MEGDNN_MARK_USED_VAR(lamb_param); | |||||
MEGDNN_MARK_USED_VAR(grad); | |||||
MEGDNN_MARK_USED_VAR(m_t); | |||||
MEGDNN_MARK_USED_VAR(v_t); | |||||
MEGDNN_MARK_USED_VAR(new_param); | |||||
return 0; | |||||
}; | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,36 @@ | |||||
#pragma once | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
namespace lamb { | |||||
struct TestArg { | |||||
param::LAMBUpdate param; | |||||
TensorShape src; | |||||
TestArg(param::LAMBUpdate param, TensorShape src) : param(param), src(src) {} | |||||
}; | |||||
inline std::vector<TestArg> get_args() { | |||||
std::vector<TestArg> args; | |||||
param::LAMBUpdate cur_param; | |||||
cur_param.beta_1 = 0.9; | |||||
cur_param.beta_2 = 0.999; | |||||
cur_param.eps = 1e-8; | |||||
cur_param.weight_decay = 0; | |||||
cur_param.lr = 6.25e-5; | |||||
cur_param.bias_correction = true; | |||||
cur_param.always_adapt = false; | |||||
args.emplace_back( | |||||
cur_param, TensorShape{ | |||||
1280, | |||||
}); | |||||
args.emplace_back(cur_param, TensorShape{1280, 1280}); | |||||
args.emplace_back(cur_param, TensorShape{1280, 3, 224, 224}); | |||||
return args; | |||||
} | |||||
} // namespace lamb | |||||
} // namespace test | |||||
} // namespace megdnn |
@@ -0,0 +1,44 @@ | |||||
#include "test/cuda/fixture.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/common/rng.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
// Compares the CUDA LAMB update against the reference for fp32, fp16 and bf16
// gradients; every other tensor is fp32.
TEST_F(CUDA, LAMBUpdate) {
    LAMBUpdate::Param param;
    param.beta_1 = 0.9;
    param.beta_2 = 0.999;
    param.eps = 1e-5;
    param.weight_decay = 0.4;
    param.lr = 1e-3;
    param.step = 1;
    param.bias_correction = true;
    param.always_adapt = false;

    Checker<LAMBUpdate> checker(handle_cuda());
    checker.set_epsilon(1e-3);
    UniformFloatRNG rng0(0, 1);

    // Tensor index 3 is the gradient; it is the only one whose dtype varies.
    constexpr size_t kGradIndex = 3;
    constexpr size_t kNrTensors = 7;
    const DType grad_dtypes[] = {
            dtype::Float32(), dtype::Float16(), dtype::BFloat16()};

    for (const DType& grad_dtype : grad_dtypes) {
        checker.set_param(param).set_rng(0, &rng0).set_rng(1, &rng0);
        for (size_t i = 0; i < kNrTensors; ++i) {
            if (i == kGradIndex) {
                checker.set_dtype(i, grad_dtype);
            } else {
                checker.set_dtype(i, dtype::Float32());
            }
        }
        checker.execs({{2}, {2}, {2}, {2}, {}, {}, {}});
    }
}
} // namespace test | |||||
} // namespace megdnn |
@@ -0,0 +1,33 @@ | |||||
#include "test/common/lamb.h" | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/naive/fixture.h" | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
// Hand-computed case: with beta_1 = beta_2 = 0, eps = 0 and lr = 1 the moments
// equal grad and grad^2, the direction is 1, and each parameter goes 1 -> 0.
TEST_F(NAIVE, LAMBUpdate) {
    LAMBUpdate::Param param;
    param.beta_1 = 0;
    param.beta_2 = 0;
    param.eps = 0;
    param.weight_decay = 0;
    param.lr = 1;
    param.step = 1;
    param.bias_correction = true;
    param.always_adapt = false;

    // Inputs.
    TensorND m_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND v_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND param_lamb = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND grad = TensorValue({2}, dtype::Float16(), {1, 1});

    // Expected outputs.
    TensorND m_t = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND v_t = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND new_param = TensorValue({2}, dtype::Float32(), {0, 0});

    Checker<LAMBUpdate> checker(handle(), false);
    checker.set_param(param).exect(
            Testcase{m_t_1, v_t_1, param_lamb, grad, {}, {}, {}},
            Testcase{{}, {}, {}, {}, m_t, v_t, new_param});
}
@@ -4,6 +4,7 @@ from .adagrad import Adagrad | |||||
from .adam import Adam | from .adam import Adam | ||||
from .adamw import AdamW | from .adamw import AdamW | ||||
from .clip_grad import * | from .clip_grad import * | ||||
from .lamb import LAMB, LAMBFp16 | |||||
from .lr_scheduler import LRScheduler | from .lr_scheduler import LRScheduler | ||||
from .multi_step_lr import MultiStepLR | from .multi_step_lr import MultiStepLR | ||||
from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
@@ -0,0 +1,160 @@ | |||||
# Copyright (c) 2020 Ross Wightman | |||||
# This file has been modified by Megvii ("Megvii Modifications"). | |||||
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
"""LAMB optimizer | |||||
References: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py | |||||
""" | |||||
import os | |||||
from typing import Iterable, Tuple, Union | |||||
from megengine.core._imperative_rt.core2 import apply | |||||
from megengine.core.ops.builtin import LAMBUpdate | |||||
from .. import Parameter, tensor | |||||
from ..functional import sum | |||||
from ..functional.inplace import _inplace_add_ | |||||
from .optimizer import Optimizer | |||||
class LAMB(Optimizer): | |||||
r"""Implements LAMB algorithm. | |||||
LAMB is proposed in `"Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" | |||||
<https://arxiv.org/abs/1904.00962>`_. | |||||
Args: | |||||
params: iterable of parameters to optimize or dicts defining parameter groups. | |||||
lr: learning rate. | |||||
betas: coefficients used for computing running averages of gradient and its square. | |||||
Default: ``(0.9, 0.999)`` | |||||
eps: term added to the denominator to improve numerical stability. Default: ``1e-8`` | |||||
bias_correction: enables bias correction by ``1 - beta ** step``. Default: ``True`` | |||||
weight_decay: weight decay (L2 penalty). Default: ``0.0`` | |||||
always_adapt: apply adaptive lr to ``0.0`` weight decay parameter. Default: ``False`` | |||||
""" | |||||
def __init__( | |||||
self, | |||||
params: Union[Iterable[Parameter], dict], | |||||
lr: float, | |||||
betas: Tuple[float, float] = (0.9, 0.999), | |||||
eps: float = 1e-8, | |||||
bias_correction: bool = True, | |||||
weight_decay: float = 0.0, | |||||
always_adapt: bool = False, | |||||
): | |||||
if lr < 0.0: | |||||
raise ValueError("Invalid learning rate: {}".format(lr)) | |||||
if weight_decay < 0.0: | |||||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) | |||||
if not 0.0 <= betas[0] < 1.0: | |||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | |||||
if not 0.0 <= betas[1] < 1.0: | |||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | |||||
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) | |||||
super().__init__(params, defaults) | |||||
self.bias_correction = bias_correction | |||||
self.always_adapt = always_adapt | |||||
self._disable_type_convert = True | |||||
def _create_state(self, param_group): | |||||
for param in param_group["params"]: | |||||
self._add_state(param, "exp_avg") | |||||
self._add_state(param, "exp_avg_sq") | |||||
self._add_state(param, "step", initializer=0.0, dtype="float32") | |||||
def _updates(self, param_group):
    """Run one ``LAMBUpdate`` step on every parameter of ``param_group`` that has a gradient.

    Parameters whose ``grad`` is ``None`` are skipped. For the others the
    step counter is incremented first, then the op is applied to
    ``(exp_avg, exp_avg_sq, param, grad)`` and its three outputs are written
    back into the parameter and its two moment states.
    """
    learning_rate = param_group["lr"]
    decay = param_group["weight_decay"]
    epsilon = param_group["eps"]
    beta0, beta1 = param_group["betas"]

    # `convert_inputs` is disabled for parameter updates, so the scalar
    # increment has to be wrapped into a tensor explicitly.
    one = tensor(1.0)

    for param in param_group["params"]:
        grad = param.grad
        if grad is None:
            continue

        state = self._state[param]
        step = state["step"]
        exp_avg = state["exp_avg"]
        exp_avg_sq = state["exp_avg_sq"]

        # The op receives the already-incremented step as a plain int.
        step += one
        fused_update = LAMBUpdate(
            beta0,
            beta1,
            int(step),
            learning_rate,
            decay,
            epsilon,
            self.bias_correction,
            self.always_adapt,
        )
        outputs = apply(fused_update, exp_avg, exp_avg_sq, param, grad)
        new_exp_avg, new_exp_avg_sq, new_param = outputs

        param._reset(new_param)
        exp_avg._reset(new_exp_avg)
        exp_avg_sq._reset(new_exp_avg_sq)
class LAMBFp16(LAMB):
    """LAMB variant that keeps float32 optimizer state for each parameter.

    Besides the float32 ``exp_avg`` / ``exp_avg_sq`` / ``step`` states, a
    float32 copy of every parameter is stored under ``"param_fp32"``. The
    update is computed on that copy and the result is written back to the
    parameter cast to float16.
    """

    def _create_state(self, param_group):
        """Create the float32 states and the float32 parameter copy."""
        for param in param_group["params"]:
            self._add_state(param, "exp_avg", dtype="float32")
            self._add_state(param, "exp_avg_sq", dtype="float32")
            self._add_state(param, "step", initializer=0.0, dtype="float32")
            # Float32 copy of the parameter that the update is applied to.
            self._state[param]["param_fp32"] = param.astype("float32")

    def _updates(self, param_group):
        """Apply one ``LAMBUpdate`` step to every parameter that has a gradient."""
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]
        beta0, beta1 = param_group["betas"]

        # The step increment is wrapped into a tensor because the counter
        # itself is a tensor state.
        c1 = tensor(1.0)

        for param in param_group["params"]:
            if param.grad is None:
                continue
            grad = param.grad
            states = self._state[param]
            step, exp_avg, exp_avg_sq = (
                states["step"],
                states["exp_avg"],
                states["exp_avg_sq"],
            )
            step += c1
            fp32_param = states["param_fp32"]
            op = LAMBUpdate(
                beta0,
                beta1,
                # Pass the step as a plain int, exactly like LAMB._updates
                # does, instead of handing the tensor state to the op.
                int(step),
                lr,
                weight_decay,
                eps,
                self.bias_correction,
                self.always_adapt,
            )
            new_exp_avg, new_exp_avg_sq, new_param = apply(
                op, exp_avg, exp_avg_sq, fp32_param, grad
            )
            # Keep the float32 copy and write a float16 cast of it back to
            # the parameter.
            fp32_param._reset(new_param)
            param._reset(new_param.astype("float16"))
            exp_avg._reset(new_exp_avg)
            exp_avg_sq._reset(new_exp_avg_sq)
@@ -0,0 +1,85 @@ | |||||
import numpy as np | |||||
import megengine as mge | |||||
import megengine.autodiff as ad | |||||
import megengine.functional as F | |||||
import megengine.module as M | |||||
import megengine.optimizer as optim | |||||
from megengine import tensor | |||||
from megengine.core._imperative_rt.core2 import apply | |||||
from megengine.core.ops.builtin import LAMBUpdate | |||||
def lamb_update(
    param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt
):
    """Reference LAMB step written with plain tensor arithmetic.

    Args:
        param_group: dict providing ``"lr"``, ``"weight_decay"``, ``"eps"``
            and ``"betas"``.
        step: step count *before* this update; it is incremented by one
            before the bias correction is computed.
        exp_avg: running average of the gradient; updated with augmented
            assignment, so the caller's tensor may be modified.
        exp_avg_sq: running average of the squared gradient; updated the
            same way as ``exp_avg``.
        param: current parameter value.
        grad: gradient of ``param``.
        bias_correction: whether to divide the moments by
            ``1 - beta ** step``.
        always_adapt: use the trust ratio even when ``weight_decay`` is 0.

    Returns:
        Tuple ``(exp_avg, exp_avg_sq, new_param)``.
    """
    lr = param_group["lr"]
    weight_decay = param_group["weight_decay"]
    eps = param_group["eps"]
    beta0, beta1 = param_group["betas"]

    # Wrap every scalar into a tensor once, so the expressions below are
    # pure tensor arithmetic.
    _lr = tensor(lr)
    _weight_decay = tensor(weight_decay)
    _eps = tensor(eps)
    _beta0, _beta1 = map(tensor, (beta0, beta1))
    c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0))

    def norm(vec):
        # L2 norm. F.sum reduces the tensor in a single op, whereas the
        # builtin sum() would walk it element by element from Python.
        return F.sum(vec * vec) ** c05

    p_norm = norm(param.flatten())

    step += c1

    # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
    exp_avg *= _beta0
    exp_avg += grad * (c1 - _beta0)

    # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
    exp_avg_sq *= _beta1
    exp_avg_sq += (c1 - _beta1) * (grad * grad)

    bias_correction1 = c1 - _beta0 ** step if bias_correction else c1
    bias_correction2 = c1 - _beta1 ** step if bias_correction else c1
    delta = (exp_avg / bias_correction1) / (
        (exp_avg_sq / bias_correction2) ** c05 + _eps
    )
    if weight_decay != 0.0:
        delta += param * _weight_decay

    d_norm = norm(delta.flatten())
    # The trust ratio is applied only when it is enabled (weight decay in
    # use or always_adapt) and both norms are strictly positive.
    trust_ratio = (
        p_norm / d_norm
        if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0
        else c1
    )
    new_param = param - _lr * trust_ratio * delta
    return exp_avg, exp_avg_sq, new_param
def test_lamb():
    """Compare the fused ``LAMBUpdate`` op with the reference ``lamb_update``."""
    shape = (256, 256)
    # Fixed seed so that a failure can be reproduced.
    rng = np.random.RandomState(0)

    op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False)
    m_t_1 = mge.tensor(rng.uniform(size=shape), dtype=np.float32)
    v_t_1 = mge.tensor(rng.uniform(size=shape), dtype=np.float32)
    params = mge.tensor(rng.uniform(size=shape), dtype=np.float32)
    grad = mge.tensor(rng.uniform(size=shape), dtype=np.float16)

    # The op has to run before the reference: lamb_update() modifies
    # m_t_1 / v_t_1 through augmented assignment.
    (new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad)

    param_group = {
        "betas": (0.9, 0.999),
        "lr": 1e-3,
        "weight_decay": 0.4,
        "eps": 1e-8,
    }
    # lamb_update() increments the step it receives before using it, while
    # the op above was built with step=1. Start the reference from 0 so that
    # both sides compute the bias correction for step 1.
    gt_m_t, gt_v_t, gt_new_param = lamb_update(
        param_group, 0, m_t_1, v_t_1, params, grad, True, False
    )

    np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2)
@@ -0,0 +1,82 @@ | |||||
#include "megbrain/imperative/opr_utility.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "../blob_manager_impl.h" | |||||
#include "../dnn_op_helper.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace { | |||||
namespace lamb { | |||||
// Returns one default-constructed layout-constraint callback per input.
// NOTE(review): presumably a default-constructed callback means "no
// constraint on this input" -- confirm against the caller.
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    const auto nr_inputs = inputs.size();
    SmallVector<VarNode::LayoutConstraintCallback> constraints(nr_inputs);
    return constraints;
}
// Infers the descriptors of the three outputs (m_t, v_t, new_param).
// Each output takes the layout and comp node of the matching input:
// m_t <- m_t_1, v_t <- v_t_1, new_param <- lamb_param. The gradient
// (input 3) does not influence any output descriptor.
// NOTE(review): the second tuple element is always `true`, even if an input
// layout is still unknown -- confirm that callers never rely on it in that
// case.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    mgb_assert(input_descs.size() == 4, "LAMBUpdate expects 4 inputs");
    const auto& m_t_1 = input_descs[0];
    const auto& v_t_1 = input_descs[1];
    const auto& lamb_param = input_descs[2];
    return {{{m_t_1.layout, m_t_1.comp_node},
             {v_t_1.layout, v_t_1.comp_node},
             {lamb_param.layout, lamb_param.comp_node}},
            true};
}
// Executes LAMBUpdate eagerly on concrete tensors.
// inputs  = (m_t_1, v_t_1, lamb_param, grad)
// returns = (m_t, v_t, new_param); each output is allocated with the layout
// and comp node of the matching input. `output_descs` and `validated` are
// not used.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<LAMBUpdate>();
    auto&& m_t_1 = inputs[0];
    auto&& v_t_1 = inputs[1];
    auto&& lamb_param = inputs[2];
    auto&& grad = inputs[3];

    // Local copies of the input layouts, handed to the allocator below.
    TensorLayout m_t_1_layout{m_t_1->layout()};
    TensorLayout v_t_1_layout{v_t_1->layout()};
    TensorLayout lamb_param_layout{lamb_param->layout()};

    DeviceTensorND m_t = BlobManager::inst()->alloc_workspace_with_defrag(
            m_t_1->comp_node(), m_t_1_layout);
    DeviceTensorND v_t = BlobManager::inst()->alloc_workspace_with_defrag(
            v_t_1->comp_node(), v_t_1_layout);
    DeviceTensorND new_param = BlobManager::inst()->alloc_workspace_with_defrag(
            lamb_param->comp_node(), lamb_param_layout);

    DnnOprCaller<megdnn::LAMBUpdate> caller{lamb_param->comp_node()};
    // Configure the kernel before asking for its workspace size, so that the
    // size is computed for the parameters exec() actually runs with.
    caller.op->param() = op.param();

    // The workspace is described as a 1-d byte tensor of the requested size.
    TensorLayout workspace_layout(
            {caller.op->get_workspace_in_bytes(
                    m_t_1->layout(), v_t_1->layout(), lamb_param->layout(),
                    grad->layout(), m_t.layout(), v_t.layout(), new_param.layout())},
            dtype::Byte());
    auto dnn_workspace = caller.create_workspace(workspace_layout);

    caller.op->exec(
            m_t_1->dev_tensor().as_megdnn(), v_t_1->dev_tensor().as_megdnn(),
            lamb_param->dev_tensor().as_megdnn(), grad->dev_tensor().as_megdnn(),
            m_t.as_megdnn(), v_t.as_megdnn(), new_param.as_megdnn(), dnn_workspace);
    return {Tensor::make(m_t), Tensor::make(v_t), Tensor::make(new_param)};
}
// Hooks up the three functions defined above as the LAMBUpdate op trait.
// NOTE(review): fallback() closes the chain and presumably supplies defaults
// for every hook that is not set explicitly -- confirm in op_trait.h.
OP_TRAIT_REG(LAMBUpdate, LAMBUpdate)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
} // namespace lamb | |||||
} // namespace | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -477,6 +477,9 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||||
def LRN: MgbHashableOp<"LRN", [LRNParam]>; | def LRN: MgbHashableOp<"LRN", [LRNParam]>; | ||||
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | ||||
def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>; | |||||
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | ||||
def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; | def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; | ||||