@@ -1442,6 +1442,39 @@ protected: | |||
void backward_check_exec(const TensorLayout& src, const TensorLayout& dst); | |||
}; | |||
class LAMBUpdate : public OperatorBase { | |||
DEF_OPR_PARAM(LAMBUpdate); | |||
// input=(m_t-1,v_t-1,lamb_param,grad) , output = (m_t,v_t,new_param) | |||
DEF_OPR_IMPL(LAMBUpdate, OperatorBase, 4, 3); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||
_megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||
_megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||
const TensorLayout& m_t, const TensorLayout& v_t, | |||
const TensorLayout& new_param) = 0; | |||
void deduce_layout( | |||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t, | |||
TensorLayout& v_t, TensorLayout& new_param); | |||
protected: | |||
void check_exec( | |||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||
const TensorLayout& m_t, const TensorLayout& v_t, | |||
const TensorLayout& new_param, size_t workspace_in_bytes); | |||
}; | |||
using LAMB = LAMBUpdate; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
@@ -36,13 +36,13 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
add_enum(Doc('Format', 'convolution data/filter/output format; see ' | |||
':class:`RelayoutFormat` for more details'), | |||
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | |||
'NCHW44 = 7','NCHW44_DOT = 8', | |||
'NCHW44 = 7','NCHW44_DOT = 8', | |||
Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights tranformed by winograd'), | |||
Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights tranformed by winograd'), | |||
Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights tranformed by winograd'), | |||
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||
Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, ' | |||
'output tensor is nchw layout'), | |||
Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | |||
@@ -96,9 +96,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
add_enum(Doc('Format', 'convolution data/filter/output format; see ' | |||
':class:`RelayoutFormat` for more details'), | |||
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | |||
'NCHW44 = 7','NCHW44_DOT = 8', | |||
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
'NCHW44 = 7','NCHW44_DOT = 8', | |||
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||
Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, ' | |||
'output tensor is nchw layout'), | |||
@@ -107,9 +107,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' | |||
'output tensor is nchw4 layout, padding c=4'), | |||
Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | |||
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'), | |||
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'), | |||
Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementation to utilizing TensorCore ' | |||
'instructions for 4-bit integers on Nvidia platforms'), | |||
'instructions for 4-bit integers on Nvidia platforms'), | |||
Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')). | |||
add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode') | |||
) | |||
@@ -1038,10 +1038,10 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||
'NCHW_NCHW4 = 24', | |||
'NCHW4_NCHW = 25', | |||
'NCHW_NCHW4_WEIGHT = 26', | |||
'NCHW_NCHW64 = 27', | |||
'NCHW64_NCHW = 28', | |||
'NCHW_NHWC = 29', | |||
'NHWC_NCHW = 30', | |||
'NCHW_NCHW64 = 27', | |||
'NCHW64_NCHW = 28', | |||
'NCHW_NHWC = 29', | |||
'NHWC_NCHW = 30', | |||
'NHWCD4I_NHWC = 31', | |||
) | |||
) | |||
@@ -1264,3 +1264,14 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||
add_fields('float32', Doc('dropout', 'If introduce a Dropout layer on the outputs of each LSTM layer'), '0.f'). | |||
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') | |||
) | |||
(pdef('LAMBUpdate'). | |||
add_fields('float32', Doc('beta_1', 'beta_1 paramter of lamb'), '1.f'). | |||
add_fields('float32', Doc('beta_2', 'beta_2 paramter of lamb'), '1.f'). | |||
add_fields('float32', Doc('step', 'training step'), '1.f'). | |||
add_fields('float32', Doc('lr', 'learning rate'), '1.f'). | |||
add_fields('float32', Doc('weight_decay', 'weight decay to adjust learning rate'), '1.f'). | |||
add_fields('float32', Doc('eps', 'eps to multi'), '1.f'). | |||
add_fields('bool', Doc('bias_correction', 'whether correct bias'), 'true'). | |||
add_fields('bool', Doc('always_adapt', 'apply adaptive lr to 0.0'), 'false') | |||
) |
@@ -209,6 +209,7 @@ private: | |||
cb(RNN) \ | |||
cb(RNNBackward) \ | |||
cb(LSTM) \ | |||
cb(LAMBUpdate) \ | |||
cb(LSTMBackward) \ | |||
cb(SoftmaxForward) \ | |||
cb(SoftmaxBackward) | |||
@@ -0,0 +1,25 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void LAMBUpdate::deduce_layout( | |||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t, | |||
TensorLayout& v_t, TensorLayout& new_param) { | |||
m_t = TensorLayout(m_t_1); | |||
v_t = TensorLayout(v_t_1); | |||
new_param = TensorLayout(lamb_param); | |||
MEGDNN_MARK_USED_VAR(grad); | |||
} | |||
void LAMBUpdate::check_exec( | |||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||
const TensorLayout& m_t, const TensorLayout& v_t, const TensorLayout& new_param, | |||
size_t workspace_in_bytes) { | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
} // namespace megdnn |
@@ -127,6 +127,7 @@ DEF(LSQBackward, 7, true, false); | |||
DEF(Fill, 1, true, false); | |||
DEF(LayerNormForward, 6, true, true); | |||
DEF(LayerNormBackward, 8, true, true); | |||
DEF(LAMBUpdate, 7, true, true); | |||
DEF(DropoutForward, 3, true, true); | |||
DEF(DropoutBackward, 3, true, true); | |||
DEF(RNNCellForward, 7, true, true); | |||
@@ -35,6 +35,7 @@ | |||
#include "src/cuda/images2neibs/opr_impl.h" | |||
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | |||
#include "src/cuda/indexing_one_hot/opr_impl.h" | |||
#include "src/cuda/lamb/opr_impl.h" | |||
#include "src/cuda/layer_norm/opr_impl.h" | |||
#include "src/cuda/linspace/opr_impl.h" | |||
#include "src/cuda/local/opr_impl.h" | |||
@@ -210,6 +211,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LAMBUpdate); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward); | |||
@@ -0,0 +1,102 @@ | |||
#include <thrust/device_vector.h> | |||
#include <thrust/pair.h> | |||
#include <thrust/transform_reduce.h> | |||
#include <thrust/tuple.h> | |||
#include <cfloat> | |||
#include "megdnn/arch.h" | |||
#include "megdnn/dtype.h" | |||
#include "src/cuda/cuda_shfl_compat.cuh" | |||
#include "src/cuda/lamb/lamb_cuda.cuh" | |||
#include "src/cuda/utils.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace lamb { | |||
template <typename T> | |||
struct square { | |||
__host__ __device__ T operator()(const T& x) const { return x * x; } | |||
}; | |||
template <typename T, typename T_ACC> | |||
__global__ void update_kernal_1( | |||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
size_t total_nr_elem) { | |||
size_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1, | |||
bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||
if (idx < total_nr_elem) { | |||
m_t[idx] = beta_1 * m_t_1[idx] + (1 - beta_1) * static_cast<T_ACC>(grad[idx]); | |||
v_t[idx] = beta_2 * v_t_1[idx] + | |||
(1 - beta_2) * std::pow(static_cast<T_ACC>(grad[idx]), 2); | |||
rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps); | |||
if (weight_decay != 0) { | |||
rt[idx] += lamb_param[idx] * weight_decay; | |||
} | |||
} | |||
} | |||
template <typename T, typename T_ACC> | |||
__global__ void update_kernal_2( | |||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
size_t total_nr_elem, T_ACC trust_ratio) { | |||
size_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1, | |||
bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||
if (idx < total_nr_elem) { | |||
rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps); | |||
if (weight_decay != 0) { | |||
rt[idx] += lamb_param[idx] * weight_decay; | |||
} | |||
new_param[idx] = lamb_param[idx] - lr * trust_ratio * rt[idx]; | |||
} | |||
} | |||
template <typename T, typename T_ACC> | |||
void update( | |||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
size_t total_nr_elem, cudaStream_t stream) { | |||
size_t NR_BLOCKS = DIVUP(total_nr_elem, NR_THREADS); | |||
update_kernal_1<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||
m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2, | |||
step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem); | |||
after_kernel_launch(); | |||
thrust::device_ptr<T_ACC> lamb_param_ptr(lamb_param); | |||
thrust::device_ptr<T_ACC> rt_ptr(rt); | |||
square<T_ACC> unary_op; | |||
thrust::plus<T_ACC> binary_op; | |||
T_ACC p_norm = std::sqrt(thrust::transform_reduce( | |||
lamb_param_ptr, lamb_param_ptr + total_nr_elem, unary_op, 0.f, binary_op)); | |||
T_ACC d_norm = std::sqrt(thrust::transform_reduce( | |||
rt_ptr, rt_ptr + total_nr_elem, unary_op, 0.f, binary_op)); | |||
T_ACC trust_ratio = 1; | |||
if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) { | |||
trust_ratio = p_norm / d_norm; | |||
} | |||
update_kernal_2<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||
m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2, | |||
step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem, | |||
trust_ratio); | |||
after_kernel_launch(); | |||
} | |||
#define INST(T, T_ACC) \ | |||
template void update<T, T_ACC>( \ | |||
T_ACC*, T_ACC*, T_ACC*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC*, float, float, \ | |||
float, float, float, float, bool, bool, size_t, cudaStream_t); | |||
INST(dt_float32, dt_float32) | |||
INST(dt_float16, dt_float32) | |||
INST(dt_bfloat16, dt_float32) | |||
#undef INST | |||
} // namespace lamb | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,17 @@ | |||
#pragma once | |||
#include <cuda_runtime_api.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace lamb { | |||
template <typename T, typename T_ACC> | |||
void update( | |||
T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
size_t total_nr_elem, cudaStream_t stream); | |||
} // namespace lamb | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,45 @@ | |||
#include "src/cuda/lamb/opr_impl.h" | |||
#include "./lamb_cuda.cuh" | |||
#include "src/cuda/utils.h" | |||
#include <cmath> | |||
#include <functional> | |||
#include <numeric> | |||
namespace megdnn { | |||
namespace cuda { | |||
void LAMBUpdateImpl::exec( | |||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param, | |||
_megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) { | |||
auto p = param(); | |||
float beta_1 = p.beta_1; | |||
float beta_2 = p.beta_2; | |||
float step = p.step; | |||
float lr = p.lr; | |||
float weight_decay = p.weight_decay; | |||
float eps = p.eps; | |||
bool bias_correction = p.bias_correction; | |||
bool always_adapt = p.always_adapt; | |||
size_t total_elem = lamb_param.layout.total_nr_elems(); | |||
auto stream = cuda_stream(handle()); | |||
using namespace ::megdnn::cuda::lamb; | |||
#define cb(DType) \ | |||
if (grad.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
using T_ACC = float; \ | |||
update<T, T_ACC>( \ | |||
m_t_1.ptr<T_ACC>(), v_t_1.ptr<T_ACC>(), lamb_param.ptr<T_ACC>(), \ | |||
grad.ptr<T>(), m_t.ptr<T_ACC>(), v_t.ptr<T_ACC>(), \ | |||
new_param.ptr<T_ACC>(), workspace.ptr<T_ACC>(), beta_1, beta_2, step, \ | |||
lr, weight_decay, eps, bias_correction, always_adapt, total_elem, \ | |||
stream); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,25 @@ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/cuda/cudnn_wrapper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class LAMBUpdateImpl final : public LAMBUpdate { | |||
public: | |||
using LAMBUpdate::LAMBUpdate; | |||
void exec( | |||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||
_megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||
_megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||
const TensorLayout& m_t, const TensorLayout& v_t, | |||
const TensorLayout& new_param) override { | |||
return m_t.access_bytes(); | |||
}; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -37,6 +37,7 @@ | |||
#include "src/naive/images2neibs/opr_impl.h" | |||
#include "src/naive/indexing_multi_axis_vec/opr_impl.h" | |||
#include "src/naive/indexing_one_hot/opr_impl.h" | |||
#include "src/naive/lamb/opr_impl.h" | |||
#include "src/naive/layer_norm/opr_impl.h" | |||
#include "src/naive/linspace/opr_impl.h" | |||
#include "src/naive/local/opr_impl.h" | |||
@@ -0,0 +1,89 @@ | |||
#include "src/naive/lamb/opr_impl.h" | |||
#include <cmath> | |||
#include <functional> | |||
#include <numeric> | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
using namespace megdnn; | |||
using namespace naive; | |||
namespace { | |||
using Param = megdnn::LAMBUpdate::Param; | |||
template <typename T, typename T_ACC = float> | |||
void update( | |||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param, | |||
_megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
_megdnn_tensor_out new_param, const Param& param) { | |||
float beta_1 = param.beta_1; | |||
float beta_2 = param.beta_2; | |||
float step = param.step; | |||
float lr = param.lr; | |||
float weight_decay = param.weight_decay; | |||
float eps = param.eps; | |||
bool bias_correction = param.bias_correction; | |||
bool always_adapt = param.always_adapt; | |||
size_t total_elem = lamb_param.layout.total_nr_elems(); | |||
T_ACC mt, vt, bc_1, bc_2, rt, d_norm = 0; | |||
bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1; | |||
bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||
for (size_t i = 0; i < total_elem; i++) { | |||
mt = m_t.ptr<T_ACC>()[i] = beta_1 * m_t_1.ptr<T_ACC>()[i] + | |||
(1 - beta_1) * static_cast<T_ACC>(grad.ptr<T>()[i]); | |||
vt = v_t.ptr<T_ACC>()[i] = | |||
beta_2 * v_t_1.ptr<T_ACC>()[i] + | |||
(1 - beta_2) * std::pow(static_cast<T_ACC>(grad.ptr<T>()[i]), 2); | |||
rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps); | |||
if (weight_decay != 0) { | |||
rt += lamb_param.ptr<T_ACC>()[i] * weight_decay; | |||
} | |||
d_norm += rt * rt; | |||
} | |||
d_norm = sqrt(d_norm); | |||
auto get_norm = [=](_megdnn_tensor_in norm) -> T_ACC { | |||
return sqrt(std::accumulate( | |||
norm.ptr<T_ACC>(), norm.ptr<T_ACC>() + total_elem, 0, | |||
[](T_ACC t1, T_ACC t2) -> T_ACC { return t1 + t2 * t2; })); | |||
}; | |||
T_ACC p_norm = get_norm(lamb_param), trust_ratio = 1; | |||
if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) { | |||
trust_ratio = p_norm / d_norm; | |||
} | |||
for (size_t i = 0; i < total_elem; i++) { | |||
mt = m_t.ptr<T_ACC>()[i]; | |||
vt = v_t.ptr<T_ACC>()[i]; | |||
rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps); | |||
if (weight_decay != 0) { | |||
rt += lamb_param.ptr<T_ACC>()[i] * weight_decay; | |||
} | |||
new_param.ptr<T_ACC>()[i] = lamb_param.ptr<T_ACC>()[i] - lr * trust_ratio * rt; | |||
} | |||
} | |||
} // namespace | |||
namespace megdnn { | |||
namespace naive { | |||
void LAMBUpdateImpl::exec( | |||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param, | |||
_megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) { | |||
check_exec( | |||
m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout, | |||
v_t.layout, new_param.layout, workspace.size); | |||
#define cb(DType) \ | |||
if (grad.layout.dtype == DType()) { \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(update<typename DTypeTrait<DType>::ctype>( \ | |||
m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, param())); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,34 @@ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class LAMBUpdateImpl final : public LAMBUpdate { | |||
public: | |||
using LAMBUpdate::LAMBUpdate; | |||
void exec( | |||
_megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||
_megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||
_megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
_megdnn_tensor_out new_param, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
const TensorLayout& lamb_param, const TensorLayout& grad, | |||
const TensorLayout& m_t, const TensorLayout& v_t, | |||
const TensorLayout& new_param) override { | |||
MEGDNN_MARK_USED_VAR(m_t_1); | |||
MEGDNN_MARK_USED_VAR(v_t_1); | |||
MEGDNN_MARK_USED_VAR(lamb_param); | |||
MEGDNN_MARK_USED_VAR(grad); | |||
MEGDNN_MARK_USED_VAR(m_t); | |||
MEGDNN_MARK_USED_VAR(v_t); | |||
MEGDNN_MARK_USED_VAR(new_param); | |||
return 0; | |||
}; | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,36 @@ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/opr_param_defs.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace lamb { | |||
struct TestArg { | |||
param::LAMBUpdate param; | |||
TensorShape src; | |||
TestArg(param::LAMBUpdate param, TensorShape src) : param(param), src(src) {} | |||
}; | |||
inline std::vector<TestArg> get_args() { | |||
std::vector<TestArg> args; | |||
param::LAMBUpdate cur_param; | |||
cur_param.beta_1 = 0.9; | |||
cur_param.beta_2 = 0.999; | |||
cur_param.eps = 1e-8; | |||
cur_param.weight_decay = 0; | |||
cur_param.lr = 6.25e-5; | |||
cur_param.bias_correction = true; | |||
cur_param.always_adapt = false; | |||
args.emplace_back( | |||
cur_param, TensorShape{ | |||
1280, | |||
}); | |||
args.emplace_back(cur_param, TensorShape{1280, 1280}); | |||
args.emplace_back(cur_param, TensorShape{1280, 3, 224, 224}); | |||
return args; | |||
} | |||
} // namespace lamb | |||
} // namespace test | |||
} // namespace megdnn |
@@ -0,0 +1,44 @@ | |||
#include "test/cuda/fixture.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/rng.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(CUDA, LAMBUpdate) { | |||
LAMBUpdate::Param param; | |||
param.beta_1 = 0.9; | |||
param.beta_2 = 0.999; | |||
param.eps = 1e-5; | |||
param.weight_decay = 0.4; | |||
param.lr = 1e-3; | |||
param.step = 1; | |||
param.bias_correction = true; | |||
param.always_adapt = false; | |||
Checker<LAMBUpdate> checker(handle_cuda()); | |||
checker.set_epsilon(1e-3); | |||
UniformFloatRNG rng0(0, 1); | |||
auto run = [&](DType d) { | |||
checker.set_param(param) | |||
.set_rng(0, &rng0) | |||
.set_rng(1, &rng0) | |||
.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_dtype(2, dtype::Float32()) | |||
.set_dtype(3, d) | |||
.set_dtype(4, dtype::Float32()) | |||
.set_dtype(5, dtype::Float32()) | |||
.set_dtype(6, dtype::Float32()) | |||
.execs({{2}, {2}, {2}, {2}, {}, {}, {}}); | |||
}; | |||
run(dtype::Float32()); | |||
run(dtype::Float16()); | |||
run(dtype::BFloat16()); | |||
} | |||
} // namespace test | |||
} // namespace megdnn |
@@ -0,0 +1,33 @@ | |||
#include "test/common/lamb.h" | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
#include "test/naive/fixture.h" | |||
using namespace megdnn; | |||
using namespace test; | |||
TEST_F(NAIVE, LAMBUpdate) { | |||
Checker<LAMBUpdate> checker(handle(), false); | |||
LAMBUpdate::Param param; | |||
param.beta_1 = 0; | |||
param.beta_2 = 0; | |||
param.eps = 0; | |||
param.weight_decay = 0; | |||
param.lr = 1; | |||
param.step = 1; | |||
param.bias_correction = true; | |||
param.always_adapt = false; | |||
TensorND m_t_1 = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
TensorND v_t_1 = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
TensorND param_lamb = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
TensorND grad = TensorValue({2}, dtype::Float16(), {1, 1}); | |||
TensorND m_t = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
TensorND v_t = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
TensorND new_param = TensorValue({2}, dtype::Float32(), {0, 0}); | |||
checker.set_param(param).exect( | |||
Testcase{m_t_1, v_t_1, param_lamb, grad, {}, {}, {}}, | |||
Testcase{{}, {}, {}, {}, m_t, v_t, new_param}); | |||
} |
@@ -4,6 +4,7 @@ from .adagrad import Adagrad | |||
from .adam import Adam | |||
from .adamw import AdamW | |||
from .clip_grad import * | |||
from .lamb import LAMB, LAMBFp16 | |||
from .lr_scheduler import LRScheduler | |||
from .multi_step_lr import MultiStepLR | |||
from .optimizer import Optimizer | |||
@@ -0,0 +1,160 @@ | |||
# Copyright (c) 2020 Ross Wightman | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
"""LAMB optimizer | |||
References: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py | |||
""" | |||
import os | |||
from typing import Iterable, Tuple, Union | |||
from megengine.core._imperative_rt.core2 import apply | |||
from megengine.core.ops.builtin import LAMBUpdate | |||
from .. import Parameter, tensor | |||
from ..functional import sum | |||
from ..functional.inplace import _inplace_add_ | |||
from .optimizer import Optimizer | |||
class LAMB(Optimizer): | |||
r"""Implements LAMB algorithm. | |||
LAMB is proposed in `"Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" | |||
<https://arxiv.org/abs/1904.00962>`_. | |||
Args: | |||
params: iterable of parameters to optimize or dicts defining parameter groups. | |||
lr: learning rate. | |||
betas: coefficients used for computing running averages of gradient and its square. | |||
Default: ``(0.9, 0.999)`` | |||
eps: term added to the denominator to improve numerical stability. Default: ``1e-8`` | |||
bias_correction: enables bias correction by ``1 - beta ** step``. Default: ``True`` | |||
weight_decay: weight decay (L2 penalty). Default: ``0.0`` | |||
always_adapt: apply adaptive lr to ``0.0`` weight decay parameter. Default: ``False`` | |||
""" | |||
def __init__( | |||
self, | |||
params: Union[Iterable[Parameter], dict], | |||
lr: float, | |||
betas: Tuple[float, float] = (0.9, 0.999), | |||
eps: float = 1e-8, | |||
bias_correction: bool = True, | |||
weight_decay: float = 0.0, | |||
always_adapt: bool = False, | |||
): | |||
if lr < 0.0: | |||
raise ValueError("Invalid learning rate: {}".format(lr)) | |||
if weight_decay < 0.0: | |||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) | |||
if not 0.0 <= betas[0] < 1.0: | |||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | |||
if not 0.0 <= betas[1] < 1.0: | |||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | |||
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) | |||
super().__init__(params, defaults) | |||
self.bias_correction = bias_correction | |||
self.always_adapt = always_adapt | |||
self._disable_type_convert = True | |||
def _create_state(self, param_group): | |||
for param in param_group["params"]: | |||
self._add_state(param, "exp_avg") | |||
self._add_state(param, "exp_avg_sq") | |||
self._add_state(param, "step", initializer=0.0, dtype="float32") | |||
def _updates(self, param_group): | |||
lr = param_group["lr"] | |||
weight_decay = param_group["weight_decay"] | |||
eps = param_group["eps"] | |||
beta0, beta1 = param_group["betas"] | |||
# since `conver_inputs` is disabled for param updates, | |||
# scalar should be explicitly tansforred to tensor | |||
c1 = tensor(1.0) | |||
for param in param_group["params"]: | |||
if param.grad is None: | |||
continue | |||
grad = param.grad | |||
states = self._state[param] | |||
step, exp_avg, exp_avg_sq = ( | |||
states["step"], | |||
states["exp_avg"], | |||
states["exp_avg_sq"], | |||
) | |||
step += c1 | |||
op = LAMBUpdate( | |||
beta0, | |||
beta1, | |||
int(step), | |||
lr, | |||
weight_decay, | |||
eps, | |||
self.bias_correction, | |||
self.always_adapt, | |||
) | |||
new_exp_avg, new_exp_avg_sq, new_param = apply( | |||
op, exp_avg, exp_avg_sq, param, grad | |||
) | |||
param._reset(new_param) | |||
exp_avg._reset(new_exp_avg) | |||
exp_avg_sq._reset(new_exp_avg_sq) | |||
class LAMBFp16(LAMB): | |||
def _create_state(self, param_group): | |||
for param in param_group["params"]: | |||
self._add_state(param, "exp_avg", dtype="float32") | |||
self._add_state(param, "exp_avg_sq", dtype="float32") | |||
self._add_state(param, "step", initializer=0.0, dtype="float32") | |||
self._state[param]["param_fp32"] = param.astype("float32") | |||
def _updates(self, param_group): | |||
lr = param_group["lr"] | |||
weight_decay = param_group["weight_decay"] | |||
eps = param_group["eps"] | |||
beta0, beta1 = param_group["betas"] | |||
c1 = tensor(1.0) | |||
for param in param_group["params"]: | |||
if param.grad is None: | |||
continue | |||
grad = param.grad | |||
states = self._state[param] | |||
step, exp_avg, exp_avg_sq = ( | |||
states["step"], | |||
states["exp_avg"], | |||
states["exp_avg_sq"], | |||
) | |||
step += c1 | |||
fp32_param = states["param_fp32"] | |||
op = LAMBUpdate( | |||
beta0, | |||
beta1, | |||
step, | |||
lr, | |||
weight_decay, | |||
eps, | |||
self.bias_correction, | |||
self.always_adapt, | |||
) | |||
new_exp_avg, new_exp_avg_sq, new_param = apply( | |||
op, exp_avg, exp_avg_sq, fp32_param, grad | |||
) | |||
fp32_param._reset(new_param) | |||
param._reset(new_param.astype("float16")) | |||
exp_avg._reset(new_exp_avg) | |||
exp_avg_sq._reset(new_exp_avg_sq) |
@@ -0,0 +1,85 @@ | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.autodiff as ad | |||
import megengine.functional as F | |||
import megengine.module as M | |||
import megengine.optimizer as optim | |||
from megengine import tensor | |||
from megengine.core._imperative_rt.core2 import apply | |||
from megengine.core.ops.builtin import LAMBUpdate | |||
def lamb_update( | |||
param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt | |||
): | |||
lr = param_group["lr"] | |||
weight_decay = param_group["weight_decay"] | |||
eps = param_group["eps"] | |||
beta0, beta1 = param_group["betas"] | |||
# since `conver_inputs` is disabled for param updates, | |||
# scalar should be explicitly tansforred to tensor | |||
_lr, _neg_lr = map(tensor, (lr, -lr)) | |||
_weight_decay = tensor(weight_decay) | |||
_eps = tensor(eps) | |||
_beta0, _beta1 = map(tensor, (beta0, beta1)) | |||
c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0)) | |||
def norm(vec): | |||
return sum(vec * vec) ** c05 | |||
p_norm = norm(param.flatten()) | |||
# step = step + c1 | |||
step += c1 | |||
# exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0) | |||
exp_avg *= _beta0 | |||
exp_avg += grad * (c1 - _beta0) | |||
# exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad) | |||
exp_avg_sq *= _beta1 | |||
exp_avg_sq += (c1 - _beta1) * (grad * grad) | |||
bias_correction1 = c1 - _beta0 ** step if bias_correction else c1 | |||
bias_correction2 = c1 - _beta1 ** step if bias_correction else c1 | |||
delta = (exp_avg / bias_correction1) / ( | |||
(exp_avg_sq / bias_correction2) ** c05 + _eps | |||
) | |||
if weight_decay != 0.0: | |||
delta += param * _weight_decay | |||
d_norm = norm(delta.flatten()) | |||
trust_ratio = ( | |||
p_norm / d_norm | |||
if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0 | |||
else c1 | |||
) | |||
new_param = param - _lr * trust_ratio * delta | |||
return exp_avg, exp_avg_sq, new_param | |||
def test_lamb(): | |||
op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False) | |||
m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32) | |||
v_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32) | |||
params = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32) | |||
grad = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float16) | |||
(new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad) | |||
param_group = { | |||
"betas": (0.9, 0.999), | |||
"step": 1, | |||
"lr": 1e-3, | |||
"weight_decay": 0.4, | |||
"eps": 1e-8, | |||
} | |||
gt_m_t, gt_v_t, gt_new_param = lamb_update( | |||
param_group, 1, m_t_1, v_t_1, params, grad, True, False | |||
) | |||
np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2) | |||
np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2) | |||
np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2) |
@@ -0,0 +1,82 @@ | |||
#include "megbrain/imperative/opr_utility.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "../blob_manager_impl.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
namespace mgb { | |||
namespace imperative { | |||
namespace { | |||
namespace lamb { | |||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||
return layout_checker; | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||
mgb_assert(input_descs.size() == 4, "IndexingOneHot expects 4inputs"); | |||
auto comp_node = input_descs[0].comp_node; | |||
auto comp_node1 = input_descs[1].comp_node; | |||
auto comp_node2 = input_descs[2].comp_node; | |||
TensorLayout m_t_1 = input_descs[0].layout, v_t_1 = input_descs[1].layout, | |||
lamb_param = input_descs[2].layout, grad = input_descs[3].layout; | |||
TensorLayout new_param = lamb_param, m_t = m_t_1, v_t = v_t_1; | |||
return {{{m_t, comp_node}, {v_t, comp_node1}, {new_param, comp_node2}}, true}; | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
auto&& op = def.cast_final_safe<LAMBUpdate>(); | |||
auto&& m_t_1 = inputs[0]; | |||
auto&& v_t_1 = inputs[1]; | |||
auto&& lamb_param = inputs[2]; | |||
auto&& grad = inputs[3]; | |||
TensorLayout m_t_1_layout{m_t_1->layout()}; | |||
TensorLayout v_t_1_layout{v_t_1->layout()}; | |||
TensorLayout lamb_param_layout{lamb_param->layout()}; | |||
DeviceTensorND m_t = BlobManager::inst()->alloc_workspace_with_defrag( | |||
m_t_1->comp_node(), m_t_1_layout); | |||
DeviceTensorND v_t = BlobManager::inst()->alloc_workspace_with_defrag( | |||
v_t_1->comp_node(), v_t_1_layout); | |||
DeviceTensorND new_param = BlobManager::inst()->alloc_workspace_with_defrag( | |||
lamb_param->comp_node(), lamb_param_layout); | |||
DnnOprCaller<megdnn::LAMBUpdate> caller{lamb_param->comp_node()}; | |||
TensorLayout m_layout( | |||
{caller.op->get_workspace_in_bytes( | |||
m_t_1->layout(), v_t_1->layout(), lamb_param->layout(), | |||
grad->layout(), m_t.layout(), v_t.layout(), new_param.layout())}, | |||
dtype::Byte()); | |||
auto dnn_workspace = caller.create_workspace(m_layout); | |||
caller.op->param() = op.param(); | |||
caller.op->exec( | |||
m_t_1->dev_tensor().as_megdnn(), v_t_1->dev_tensor().as_megdnn(), | |||
lamb_param->dev_tensor().as_megdnn(), grad->dev_tensor().as_megdnn(), | |||
m_t.as_megdnn(), v_t.as_megdnn(), new_param.as_megdnn(), dnn_workspace); | |||
return {Tensor::make(m_t), Tensor::make(v_t), Tensor::make(new_param)}; | |||
} | |||
OP_TRAIT_REG(LAMBUpdate, LAMBUpdate) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.get_input_layout_constraint(get_input_layout_constraint) | |||
.fallback(); | |||
} // namespace lamb | |||
} // namespace | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -477,6 +477,9 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||
def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | |||
def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>; | |||
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | |||
def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; | |||