GitOrigin-RevId: 01b5324392
tags/v1.9.0
@@ -28,6 +28,16 @@ void BNForward::check_exec(
         const TensorLayout& variance, const TensorLayout& batch_mean,
         const TensorLayout& batch_inv_variance, const TensorLayout& dst,
         size_t workspace_in_bytes, size_t reserve_in_bytes) {
+    // moving some python asserts into dnn to decrease the assert overhead
+    megdnn_assert(
+            src.ndim == 4,
+            "ndim of the input tensor for batch_norm should be 4, but got %zu",
+            src.ndim);
+    megdnn_assert(bn_scale.ndim == 4, "expected 4, got %zu\n", bn_scale.ndim);
+    megdnn_assert(bn_bias.ndim == 4, "expected 4, got %zu\n", bn_bias.ndim);
+    megdnn_assert_eq_layout(bn_scale, bn_bias);
+    megdnn_assert_eq_layout(batch_mean, batch_inv_variance);
+    megdnn_assert_contiguous(src);
     megdnn_assert_eq_layout(src, dst);
     megdnn_assert_eq_layout(bn_scale, bn_bias);
@@ -58,16 +58,19 @@ class autocast:
         self._origin_low = None

     def __enter__(self):
-        self._origin_enabled, amp._enabled = amp._enabled, self.enabled
-        self._origin_high = amp._high_prec_dtype
-        amp._high_prec_dtype = self.high_prec_dtype
-        self._origin_low = amp._low_prec_dtype
-        amp._low_prec_dtype = self.low_prec_dtype
+        self._origin_enabled = amp._enabled
+        self._origin_high = amp._get_amp_high_prec_dtype()
+        self._origin_low = amp._get_amp_low_prec_dtype()
+        amp._enabled = self.enabled
+        amp._set_amp_dtype_autocast(self.enabled)
+        amp._set_amp_high_prec_dtype(self.high_prec_dtype)
+        amp._set_amp_low_prec_dtype(self.low_prec_dtype)

     def __exit__(self, *args):
         amp._enabled = self._origin_enabled
-        amp._high_prec_dtype = self._origin_high
-        amp._low_prec_dtype = self._origin_low
+        amp._set_amp_dtype_autocast(self._origin_enabled)
+        amp._set_amp_high_prec_dtype(self._origin_high)
+        amp._set_amp_low_prec_dtype(self._origin_low)

     def __call__(self, func):
         @functools.wraps(func)
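
For orientation, a hedged usage sketch of the rewired context manager: `__enter__` saves the old amp state and installs the new one through the `_set_amp_*` bindings, and `__exit__` restores it. Tensor shapes are illustrative only; the dtype pair matches the values exercised by the test further down.

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    x = tensor(np.random.randn(1, 3, 8, 8).astype("float32"))
    w = tensor(np.random.randn(8, 3, 3, 3).astype("float32"))
    # state set via _set_amp_* is active inside the block
    with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
        y = F.conv2d(x, w)
    # on exit, the previously saved amp state is restored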
@@ -5,9 +5,18 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .._imperative_rt.core2 import (
+    _get_amp_dtype_autocast,
+    _get_amp_high_prec_dtype,
+    _get_amp_low_prec_dtype,
+    _set_amp_dtype_autocast,
+    _set_amp_high_prec_dtype,
+    _set_amp_low_prec_dtype,
+)

 _enabled = False
-_high_prec_dtype = "float32"
-_low_prec_dtype = "float16"
+_set_amp_dtype_autocast(_enabled)


 @property
@@ -28,6 +37,7 @@ def enabled(mod):
 def enabled(mod, enabled: bool):
     global _enabled
     _enabled = enabled
+    _set_amp_dtype_autocast(_enabled)


 @property
@@ -42,13 +52,12 @@ def high_prec_dtype(mod):
     import megengine as mge

     mge.amp.high_prec_dtype = "float32"
     """
-    return _high_prec_dtype
+    return _get_amp_high_prec_dtype()


 @high_prec_dtype.setter
 def high_prec_dtype(mod, dtype: str):
-    global _high_prec_dtype
-    _high_prec_dtype = dtype
+    _set_amp_high_prec_dtype(dtype)


 @property
@@ -63,10 +72,9 @@ def low_prec_dtype(mod):
     import megengine as mge

     mge.amp.low_prec_dtype = "float16"
     """
-    return _low_prec_dtype
+    return _get_amp_low_prec_dtype()


 @low_prec_dtype.setter
 def low_prec_dtype(mod, dtype: str):
-    global _low_prec_dtype
-    _low_prec_dtype = dtype
+    _set_amp_low_prec_dtype(dtype)
@@ -25,7 +25,6 @@ from .utils import (
     astensor1d,
     astype,
     cast_tensors,
-    convert_inputs,
     make_shape_tuple,
     subgraph,
 )
@@ -40,38 +39,6 @@ def _elwise_apply(args, mode):
 def _elwise(*args, mode):
-    args = convert_inputs(*args)
-    if (
-        mode
-        in (
-            _ElwMod.TRUE_DIV,
-            _ElwMod.EXP,
-            _ElwMod.POW,
-            _ElwMod.LOG,
-            _ElwMod.EXPM1,
-            _ElwMod.LOG1P,
-            _ElwMod.ACOS,
-            _ElwMod.ASIN,
-            _ElwMod.ATAN2,
-            _ElwMod.COS,
-            _ElwMod.SIN,
-            _ElwMod.LOG_SUM_EXP,
-        )
-        and (
-            amp._enabled
-            or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
-        )
-        or mode in (_ElwMod.TANH,)
-        and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
-    ):
-        # autocast to FP32 to maintain precision
-        # or to avoid ops not supporting all int args
-        args = cast_tensors(*args, promote=True)
-    if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
-        args[0].dtype, np.integer
-    ):
-        return args[0]
     return _elwise_apply(args, mode)
@@ -504,10 +471,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
-        if mode == "mean":
-            data = data.astype("float32")
-        elif self.dtype == np.bool_:
-            data = data.astype("int32")
         if axis is None:
             assert not keepdims, "can not set axis=None and keepdims=True"
             result = _reduce_to_scalar(builtin.Reduce(mode=mode), data)
@@ -526,9 +489,6 @@ def _reduce(mode):
         if not keepdims:
             result = _remove_axis(result, axis)
-        if self.dtype == np.bool_:
-            if mode in ["min", "max"]:
-                result = result.astype("bool")
         return result

     return f
@@ -16,6 +16,8 @@ from .._imperative_rt import make_const
 from .._imperative_rt.core2 import (
     SymbolVar,
     Tensor,
+    _get_convert_inputs,
+    _set_convert_inputs,
     apply,
     dtype_promotion,
     get_device,
@@ -27,15 +29,13 @@ from .._wrap import as_device
 from ..autodiff.grad import Function
 from ..ops import builtin
 from ..ops.special import Const
-from .amp import _high_prec_dtype, _low_prec_dtype
+from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize

-_enable_convert_inputs = True


 def get_convert_inputs():
     r"""get the current state of `_enable_convert_inputs`"""
-    return _enable_convert_inputs
+    return _get_convert_inputs()


 def set_convert_inputs(flag):
@@ -44,10 +44,7 @@ def set_convert_inputs(flag):
     `_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
     internal use only, and should be removed when the tensor-like system is refactored.
     """
-    global _enable_convert_inputs
-    backup = _enable_convert_inputs
-    _enable_convert_inputs = flag
-    return backup
+    return _set_convert_inputs(flag)


 def concatenate(inputs, axis=0, *, device=None):
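
Both the old Python global and the new C++-backed flag keep the same contract: the setter returns the previous value so callers can restore it. A hedged sketch of the backup/restore idiom, assuming the module lives at megengine/core/tensor/utils.py as the imports above suggest:

    from megengine.core.tensor.utils import set_convert_inputs

    backup = set_convert_inputs(False)  # disable implicit input conversion
    try:
        pass  # internal code that guarantees all inputs are already Tensors
    finally:
        set_convert_inputs(backup)  # restore the previous state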
@@ -75,7 +72,7 @@ def convert_single_value(v, *, dtype=None, device=None):
 def convert_inputs(*args, device=None):
-    if not _enable_convert_inputs:
+    if not _get_convert_inputs():
         return args

     dtype = dtype_promotion(args)
@@ -109,9 +106,9 @@ def convert_inputs(*args, device=None):
 def cast_tensors(*args, promote=False):
     if promote:
-        dtype = _high_prec_dtype
+        dtype = _get_amp_high_prec_dtype()
     else:
-        dtype = _low_prec_dtype
+        dtype = _get_amp_low_prec_dtype()
     return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
@@ -16,6 +16,7 @@ from ..core.tensor.array_method import _elwise
 from ..core.tensor.utils import convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
+from .tensor_cache import get_scalar_one

 __all__ = [
     "abs",
@@ -359,7 +360,11 @@ def asin(x):
 def atan(x):
     r"""Element-wise `inverse tangent`."""
-    return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)
+    return _elwise(
+        x,
+        get_scalar_one("float32", x.device if isinstance(x, Tensor) else None),
+        mode=Elemwise.Mode.ATAN2,
+    )


 def atan2(y, x):
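
The rewrite keeps the identity atan(x) = atan2(x, 1) but draws the constant 1 from the new per-(dtype, device) cache instead of building a fresh tensor on every call. A quick sanity sketch of that identity (input values are illustrative):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.linspace(-2.0, 2.0, 5, dtype="float32"))
    np.testing.assert_allclose(F.atan(x).numpy(), np.arctan(x.numpy()), rtol=1e-6)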
@@ -253,15 +253,6 @@ def conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    if amp._enabled:
-        compute_mode = "float32"
-        inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        dtype = dtype_promotion(inp, weight)
-        if inp.dtype != dtype:
-            inp = inp.astype(dtype)
-        if weight.dtype != dtype:
-            weight = weight.astype(dtype)
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -1328,29 +1319,32 @@ def batch_norm(
         inplace: whether to update ``running_mean`` and ``running_var``
             inplace or return new tensors. Default: True
     """
+    if inp.ndim != 4:
+        raise NotImplementedError("batch_norm for ndim != 4")
+
+    if param_dim == "dim_1c11":
+        C = inp.shape[1]
+        pshape = (1, C, 1, 1)
+    elif param_dim == "dim_111c":
+        C = inp.shape[3]
+        pshape = (1, 1, 1, C)
+    else:
+        raise ValueError("Invalid param_dim {}".format(param_dim))

     def make_full_if_none(x, value):
+        x_ndim = None if x is None else x.ndim
+        # in the general case, x is returned here directly
+        if x_ndim is not None and x_ndim != 1:
+            return x
-        if param_dim == "dim_1c11":
-            C = inp.shape[1]
-            pshape = (1, C, 1, 1)
-        elif param_dim == "dim_111c":
-            C = inp.shape[3]
-            pshape = (1, 1, 1, C)
-        else:
-            raise ValueError("Invalid param_dim {}".format(param_dim))
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
-        elif x.ndim == 1:
+        else:
+            assert x_ndim == 1
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
-        return x

     has_mean = running_mean is not None
     has_var = running_var is not None
@@ -1359,16 +1353,6 @@ def batch_norm(
         assert has_mean, "running_mean must be provided in inference mode"
         assert has_var, "running_var must be provided in inference mode"

-    if has_mean and running_mean.ndim != 4:
-        raise ValueError
-    if has_var and running_var.ndim != 4:
-        raise ValueError
-
-    if amp._enabled:
-        inp = inp.astype("float16")
-        weight, bias, running_mean, running_var = cast_tensors(
-            weight, bias, running_mean, running_var, promote=True
-        )
     weight = make_full_if_none(weight, 1)
     bias = make_full_if_none(bias, 0)
@@ -0,0 +1,34 @@
+from ..core.ops.special import Const
+from ..jit.tracing import is_tracing
+
+small_tensor_cache = {}
+
+
+def _get_scalar_tensor_with_value(value, dtype=None, device=None):
+    global small_tensor_cache
+    if is_tracing():
+        (ret,) = Const(value, dtype=dtype, device=device)()
+    else:
+        cache_key = (value, dtype, device)
+        if cache_key not in small_tensor_cache:
+            (ret,) = Const(value, dtype=dtype, device=device)()
+            small_tensor_cache[cache_key] = ret
+        else:
+            ret = small_tensor_cache[cache_key]
+    return ret
+
+
+def get_scalar_zero(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(0, dtype, device)
+
+
+def get_scalar_zero_point_five(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(0.5, dtype, device)
+
+
+def get_scalar_one(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(1, dtype, device)
+
+
+def get_scalar_two(dtype=None, device=None):
+    return _get_scalar_tensor_with_value(2, dtype, device)
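
The cache is keyed by (value, dtype, device) and bypassed while tracing, where each use must emit its own Const into the traced graph. A hedged sketch of the eager-mode behavior, assuming the new file lands at megengine/functional/tensor_cache.py as the import in elemwise suggests:

    from megengine.functional.tensor_cache import get_scalar_one

    a = get_scalar_one("float32", None)
    b = get_scalar_one("float32", None)
    assert a is b  # outside tracing, the same cached handle is returned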
@@ -15,6 +15,7 @@
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/profiler.h"
+#include "megbrain/imperative/transformations/dtype_promote.h"
 #include "megbrain/imperative/transformations/eval.h"
 #include "megbrain/imperative/transformations/lazy.h"
 #include "megbrain/imperative/transformations/scalar.h"
@@ -59,16 +60,19 @@ struct SymbolVarContext {
     TransformationContext context;
     std::shared_ptr<SymbolTransformation> symbol_tsf;
     std::shared_ptr<ScalarTransformation> scalar_tsf;
+    std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;

     SymbolVarContext(cg::ComputingGraph* graph) {
         symbol_tsf = std::make_shared<SymbolTransformation>(graph);
         scalar_tsf = std::make_shared<ScalarTransformation>();
+        dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
         Transformation::swap_context(context);
     }

     void init() {
         symbol_tsf->register_at(Transformation::top());
         scalar_tsf->register_at(Transformation::top());
+        dtype_promote_tsf->register_at(Transformation::top());
     }

     ValueRef symvar2val(py::handle py_symbol_var) {
@@ -110,6 +114,9 @@ REGISTE_APPLY_FUNC(cpp_astensor1d)
 #undef REGISTE_APPLY_FUNC

+PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
+CompNode _get_device(PyObject* const* args, size_t nargs);
+
 PyObject* py_apply(
         PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
     try {
@@ -133,19 +140,59 @@ PyObject* py_apply(
         auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
         SmallVector<ValueRef, 8> tensors(nargs);

-        bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) &&
-                             py::isinstance<PySymbolVar>(py::handle(args[0]));
-        if (is_symbol_var) {
+        SmallVector<bool, 8> is_symbol_var(nargs, false);
+        ComputingGraph* cg = nullptr;
+        for (size_t i = 0; i < nargs; ++i) {
+            if ((!TensorWrapper::try_cast(args[i])) &&
+                py::isinstance<PySymbolVar>(py::handle(args[i]))) {
+                is_symbol_var[i] = true;
+                ComputingGraph* cur_cg =
+                        py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
+                if (cg == nullptr) {
+                    cg = cur_cg;
+                } else {
+                    mgb_assert(cg == cur_cg);
+                }
+            }
+        }
+
+        mgb::CompNode target_cn;
+        mgb::DType target_dtype;
+
+        auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
+            if (!target_dtype.valid()) {
+                target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
+                target_cn = _get_device(args, nargs);
+            }
+            HostTensorND ht(target_cn);
+            ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
+            if (PyArray_Check(args[i])) {  // non-scalar
+                return imperative::apply(
+                        CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
+                        HostStorage::make(ht.storage()))[0];
+            } else {  // scalar
+                return imperative::apply(
+                        CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
+                        HostStorage::make(ht.storage()))[0];
+            }
+        };
+
+        if (cg != nullptr) {
             // swap to a special context to reuse scalar handle
-            SymbolVarContext context(
-                    py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
+            size_t symbol_var_idx = 8;
+            SymbolVarContext context(cg);
             context.init();
             for (size_t i = 0; i < nargs; ++i) {
-                tensors[i] = context.symvar2val(args[i]);
+                if (is_symbol_var[i]) {
+                    symbol_var_idx = i;
+                    tensors[i] = context.symvar2val(args[i]);
+                } else {
+                    tensors[i] = convert_pyinput_to_tensor(i);
+                }
             }
             auto outputs = imperative::apply(*op, tensors);
             auto ret = pybind11::tuple(outputs.size());
-            auto typeobj = py::handle(args[0]).get_type();
+            auto typeobj = py::handle(args[symbol_var_idx]).get_type();
             for (size_t i = 0; i < outputs.size(); ++i) {
                 ret[i] = context.val2symvar(typeobj, outputs[i]);
             }
@@ -156,13 +203,7 @@ PyObject* py_apply(
             if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                 tensors[i] = tw->m_tensor->data();
             } else {
-                PyErr_SetString(
-                        PyExc_TypeError,
-                        ssprintf(
-                                "op %s expect type Tensor as inputs, got %s actually",
-                                op->make_name().c_str(), Py_TYPE(args[i])->tp_name)
-                                .c_str());
-                return nullptr;
+                tensors[i] = convert_pyinput_to_tensor(i);
             }
         }
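
With the TypeError branch gone, py_apply now accepts non-Tensor operands and converts them in C++ via convert_pyinput_to_tensor, instead of bouncing them back to Python. A hedged eager-mode sketch of what this enables (values illustrative):

    import numpy as np
    from megengine import tensor

    x = tensor(np.ones((2, 2), dtype="float32"))
    y = x + 1                                 # Python scalar converted in C++
    z = x + np.ones((2, 2), dtype="float32")  # numpy array converted in C++
    assert y.dtype == np.float32
    assert z.dtype == np.float32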
@@ -616,6 +657,8 @@ void init_tensor(py::module m) {
                     std::shared_ptr<Channel>(channel, [](Channel*) {})));
     transformations.register_at<Segment::Scalar>(
             std::make_shared<ScalarTransformation>());
+    transformations.register_at<Segment::DTypePromote>(
+            std::make_shared<DTypePromoteTransformation>());

     static py::exception<interpreter::AsyncError> py_async_error(
             m, "AsyncError", PyExc_RuntimeError);
@@ -1137,6 +1180,63 @@ void init_tensor(py::module m) {
     m.def("reset_stats", [] { imperative::Stats::reset(); });

+    m.def("_get_convert_inputs",
+          []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
+
+    m.def("_set_convert_inputs", [](bool flag) -> bool {
+        bool ret = DTypePromoteCfg::convert_input_enabled;
+        DTypePromoteCfg::convert_input_enabled = flag;
+        return ret;
+    });
+
+    m.def("_get_amp_dtype_autocast",
+          []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
+
+    m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
+        bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
+        DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
+        return ret;
+    });
+
+    static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
+        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
+                                : DTypePromoteCfg::amp_low_prec_dtype;
+        mgb_assert(target.category() == DTypeCategory::FLOAT);
+        std::string ret = target.name();
+        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
+        return ret;
+    };
+
+    static auto set_amp_prec_dtype = [](bool is_high,
+                                        std::string dtype_name) -> std::string {
+        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
+                                : DTypePromoteCfg::amp_low_prec_dtype;
+        std::string ret = target.name();
+        if (dtype_name == "float32") {
+            target = dtype::Float32();
+        } else if (dtype_name == "float16") {
+            target = dtype::Float16();
+        } else if (dtype_name == "bfloat16") {
+            target = dtype::BFloat16();
+        } else {
+            mgb_assert(
+                    false, "the amp precision dtype should be a float dtype, but got %s\n",
+                    dtype_name.c_str());
+        }
+        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
+        return ret;
+    };
+
+    m.def("_get_amp_high_prec_dtype",
+          []() -> std::string { return get_amp_prec_dtype(true); });
+
+    m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
+        return set_amp_prec_dtype(true, dtype_name);
+    });
+
+    m.def("_get_amp_low_prec_dtype",
+          []() -> std::string { return get_amp_prec_dtype(false); });
+
+    m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
+        return set_amp_prec_dtype(false, dtype_name);
+    });
+
     py::register_exception<TraceError>(m, "TraceError");
 }
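
The setters deliberately return the previous (lowercased) dtype name, mirroring the backup/restore idiom used elsewhere in the patch. A hedged round-trip sketch against the new private bindings, with the import path assumed from the amp module imports above:

    from megengine.core._imperative_rt.core2 import (
        _get_amp_high_prec_dtype,
        _set_amp_high_prec_dtype,
    )

    prev = _set_amp_high_prec_dtype("bfloat16")  # returns the old name, e.g. "float32"
    assert _get_amp_high_prec_dtype() == "bfloat16"
    _set_amp_high_prec_dtype(prev)               # restore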
@@ -26,12 +26,13 @@ struct TransformationManager {
     enum Segment {
         ModuleTrace,
         Grad,
+        DTypePromote,
         Scalar,
         Trace,
         Eval,
     };

-    std::array<std::vector<std::shared_ptr<Transformation>>, 5> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments;

     template <Segment segment>
     void register_at(std::shared_ptr<Transformation> transformation) {
@@ -14,20 +14,20 @@ def test_grad_scaler():
         assert amp.enabled == enabled
         assert origin_amp._enabled == enabled
         assert amp.low_prec_dtype == low
-        assert origin_amp._low_prec_dtype == low
+        assert origin_amp._get_amp_low_prec_dtype() == low
         assert amp.high_prec_dtype == high
-        assert origin_amp._high_prec_dtype == high
+        assert origin_amp._get_amp_high_prec_dtype() == high

     origin_enabled = amp.enabled
     origin_high = amp.high_prec_dtype
     origin_low = amp.low_prec_dtype
-    with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
-        check(True, "low", "high")
+    with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
+        check(True, "float16", "float32")
     check(origin_enabled, origin_low, origin_high)
     amp.enabled = True
-    amp.high_prec_dtype = "high"
-    amp.low_prec_dtype = "low"
-    check(True, "low", "high")
+    amp.high_prec_dtype = "float32"
+    amp.low_prec_dtype = "float16"
+    check(True, "float16", "float32")
     amp.enabled = origin_enabled
     amp.high_prec_dtype = origin_high
     amp.low_prec_dtype = origin_low
@@ -0,0 +1,251 @@
+#include "megbrain/imperative/transformations/dtype_promote.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+namespace mgb::imperative {
+
+bool DTypePromoteCfg::convert_input_enabled = true;
+bool DTypePromoteCfg::amp_dtype_autocast_enabled = false;
+DType DTypePromoteCfg::amp_high_prec_dtype = dtype::Float32();
+DType DTypePromoteCfg::amp_low_prec_dtype = dtype::Float16();
+
+namespace {
+
+// TODO: ScalarRule and DTypePromoteRule should be unified
+using DTypePromoteRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>;
+static std::unordered_map<Typeinfo*, DTypePromoteRule> dtype_promotion_rules;
+
+template <typename T>
+void register_dtype_promote_rule(const DTypePromoteRule& rule) {
+    dtype_promotion_rules[T::typeinfo()] = [rule](const OpDef& def,
+                                                  Span<ValueRef> inputs) {
+        return rule(def.cast_final_safe<T>(), inputs);
+    };
+}
+
+bool is_quantized_dtype(const DType& dtype) {
+    return dtype.category() == DTypeCategory::QUANTIZED;
+}
+
+bool is_all_integer(const SmallVector<DType>& dtypes) {
+    for (size_t i = 0; i < dtypes.size(); ++i) {
+        if (dtypes[i].category() != DTypeCategory::INT) {
+            return false;
+        }
+    }
+    return true;
+}
+
+SmallVector<DType> get_value_dtypes(const Span<ValueRef> inputs) {
+    SmallVector<DType> dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+    return dtypes;
+}
+
+mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) {
+    if (dtypes.size() == 0) {
+        mgb_assert(false, "there are no inputs for the operator, dtype promotion failed");
+    }
+    mgb::DType ret = dtypes[0];
+    for (size_t i = 1; i < dtypes.size(); ++i) {
+        ret = mgb::dtype_promotion(ret, dtypes[i]);
+    }
+    return ret;
+}
+
+ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& elem_op = op.cast_final_safe<Elemwise>();
+
+    SmallVector<DType> dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+
+    ValueRefList converted(inputs.size());
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    // TODO: we could save the input dtypes here and perform TypeCvt once at the end
+    // of this function rather than eagerly. But for compatibility, this function
+    // follows the same process as the python version and performs TypeCvt in place,
+    // so TypeCvt may run several times in this function
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype &&
+            DTypePromoteCfg::convert_input_enabled) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            dtypes[i] = target_dtype;
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    static std::unordered_set<Elemwise::Mode> cast_case1 = {
+            Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP,
+            Elemwise::Mode::POW,      Elemwise::Mode::LOG,
+            Elemwise::Mode::EXPM1,    Elemwise::Mode::LOG1P,
+            Elemwise::Mode::ACOS,     Elemwise::Mode::ASIN,
+            Elemwise::Mode::ATAN2,    Elemwise::Mode::COS,
+            Elemwise::Mode::SIN,      Elemwise::Mode::LOG_SUM_EXP,
+    };
+
+    static std::unordered_set<Elemwise::Mode> cast_case2 = {
+            Elemwise::Mode::TANH,
+    };
+
+    auto cast_to_high_prec = [&]() {
+        for (size_t i = 0; i < dtypes.size(); ++i) {
+            if (dtypes[i] != DTypePromoteCfg::amp_high_prec_dtype) {
+                converted[i] = imperative::apply(
+                        ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)),
+                        converted[i])[0];
+                dtypes[i] = DTypePromoteCfg::amp_high_prec_dtype;
+            }
+        }
+    };
+
+    if (cast_case1.find(elem_op.mode) != cast_case1.end()) {
+        if (DTypePromoteCfg::amp_dtype_autocast_enabled || is_all_integer(dtypes)) {
+            cast_to_high_prec();
+        }
+    }
+
+    if (cast_case2.find(elem_op.mode) != cast_case2.end()) {
+        if (is_all_integer(dtypes)) {
+            cast_to_high_prec();
+        }
+    }
+
+    static std::unordered_set<Elemwise::Mode> cast_case3 = {
+            Elemwise::Mode::CEIL, Elemwise::Mode::FLOOR, Elemwise::Mode::ROUND};
+
+    if (cast_case3.find(elem_op.mode) != cast_case3.end()) {
+        if (dtypes[0].category() == DTypeCategory::INT) {
+            return converted;
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
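
This reproduces the Python branch removed from _elwise above: cast_case1 modes are promoted to the high-precision float dtype when amp autocast is on or all inputs are integers, TANH only in the all-integer case, and CEIL/FLOOR/ROUND on integers return the input unchanged. A hedged sketch of the observable effect with default settings:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    a = tensor(np.array([1, 2], dtype="int32"))
    b = tensor(np.array([2, 2], dtype="int32"))
    assert (a / b).dtype == np.float32   # TRUE_DIV on all-int inputs promotes
    assert F.floor(a).dtype == np.int32  # CEIL/FLOOR/ROUND pass integers through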
+ValueRefList reduce_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& reduce_op = op.cast_final_safe<Reduce>();
+    DType org_dtype = *(inputs[0].dtype());
+    DType target_dtype = org_dtype;
+
+    ValueRefList converted(inputs.begin(), inputs.end());
+
+    if (reduce_op.mode == Reduce::Mode::MEAN) {
+        target_dtype = dtype::Float32();
+    } else if (org_dtype.category() == DTypeCategory::BOOL) {
+        target_dtype = dtype::Int32();
+    }
+
+    if (target_dtype != org_dtype) {
+        converted[0] =
+                imperative::apply(ApplyOp(*TypeCvt::make(target_dtype)), inputs[0])[0];
+    }
+
+    ValueRefList ret = imperative::apply(op, converted);
+
+    if (org_dtype.category() == DTypeCategory::BOOL) {
+        if (reduce_op.mode == Reduce::Mode::MIN ||
+            reduce_op.mode == Reduce::Mode::MAX) {
+            ret[0] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(dtype::Bool())), ret[0])[0];
+        }
+    }
+
+    return ret;
+}
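
This mirrors the dtype handling removed from the Python _reduce above: MEAN computes in float32, bool inputs reduce via int32, and bool MIN/MAX results are cast back to bool. A hedged sketch of the preserved semantics:

    import numpy as np
    from megengine import tensor

    x = tensor(np.arange(4, dtype="int32"))
    assert x.mean().dtype == np.float32  # MEAN always computes in float32

    b = tensor(np.array([True, False]))
    assert b.min().dtype == np.bool_     # bool MIN/MAX is cast back to bool
    assert b.sum().dtype == np.int32     # other bool reduces stay int32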
+ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& conv_op = const_cast<Convolution&>(op.cast_final_safe<Convolution>());
+    SmallVector<DType> dtypes = get_value_dtypes(inputs);
+    mgb::DType target_dtype;
+
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        conv_op.compute_mode = Convolution::ComputeMode::FLOAT32;
+        target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
+    } else {
+        target_dtype = get_promoted_dtype(dtypes);
+    }
+
+    ValueRefList converted(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (dtypes[i] != target_dtype) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
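
Under amp autocast this replaces the branch removed from conv2d: inputs are cast to the low-precision dtype and compute_mode is forced to FLOAT32; otherwise ordinary promotion applies. A hedged sketch, assuming autocast defaults to enabled=True and that the output dtype follows the cast inputs:

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    inp = tensor(np.random.randn(1, 3, 8, 8).astype("float32"))
    w = tensor(np.random.randn(4, 3, 3, 3).astype("float32"))
    with amp.autocast():
        out = F.conv2d(inp, w)  # inputs cast to float16, accumulation in float32
    assert out.dtype == np.float16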
+ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        mgb_assert(inputs.size() > 0);
+        ValueRefList converted(inputs.size());
+        converted[0] = imperative::apply(
+                ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0];
+
+        for (size_t i = 1; i < inputs.size(); ++i) {
+            DType idtype = *(inputs[i].dtype());
+            if (idtype != DTypePromoteCfg::amp_high_prec_dtype) {
+                converted[i] = imperative::apply(
+                        ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)),
+                        inputs[i])[0];
+            } else {
+                converted[i] = inputs[i];
+            }
+        }
+
+        return imperative::apply(op, converted);
+    }
+
+    return imperative::apply(op, inputs);
+}
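
Likewise for batch_norm under amp: the input is cast to float16 while the statistics and affine parameters are promoted to the high-precision dtype, matching the Python branch removed earlier. A hedged sketch; the momentum/eps/inplace arguments are illustrative and the float16 output assumes it follows the cast input:

    import numpy as np
    import megengine.functional as F
    from megengine import amp, tensor

    x = tensor(np.random.randn(2, 3, 4, 4).astype("float32"))
    with amp.autocast():
        out = F.batch_norm(x, training=True, momentum=0.9, eps=1e-5, inplace=False)
    assert out.dtype == np.float16  # input in float16, stats kept in float32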
+struct DTypePromoteRuleRegistry {
+    DTypePromoteRuleRegistry() {
+        register_dtype_promote_rule<Elemwise>(elemwise_rule);
+        register_dtype_promote_rule<Reduce>(reduce_rule);
+        register_dtype_promote_rule<Convolution>(convolution_rule);
+        register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
+    }
+} register_helper;
+
+}  // namespace
+
+ValueRefList DTypePromoteTransformation::apply_transformation(
+        const Operator& op, Span<ValueRef> inputs) {
+    if (auto apply_op = op.as<ApplyOp>()) {
+        auto iter = dtype_promotion_rules.find(apply_op->op().dyn_typeinfo());
+        if (iter != dtype_promotion_rules.end()) {
+            return iter->second(apply_op->op(), inputs);
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    }
+    return imperative::apply(op, inputs);
+}
+
+ValueRef DTypePromoteTransformation::unwrap(ValueRef value) {
+    return value;
+}
+
+std::string DTypePromoteTransformation::name() const {
+    return "DTypePromoteTransformation";
+}
+
+void DTypePromoteTransformation::on_register() {
+    // printf("DTypePromoteTransformation has been registered\n");
+}
+
+void DTypePromoteTransformation::on_unregister() noexcept {
+    // printf("DTypePromoteTransformation has been unregistered\n");
+}
+
+}  // namespace mgb::imperative
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb::imperative {
+
+class DTypePromoteTransformation final : public Transformation {
+private:
+public:
+    ValueRefList apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override;
+
+    ValueRef unwrap(ValueRef value) override;
+
+    std::string name() const override;
+
+    void on_register() override;
+
+    void on_unregister() noexcept override;
+};
+
+struct DTypePromoteCfg {
+    static bool convert_input_enabled;
+    static bool amp_dtype_autocast_enabled;
+    static DType amp_high_prec_dtype;
+    static DType amp_low_prec_dtype;
+};
+
+}  // namespace mgb::imperative