GitOrigin-RevId: 01b5324392
tags/v1.9.0
@@ -28,6 +28,16 @@ void BNForward::check_exec( | |||||
const TensorLayout& variance, const TensorLayout& batch_mean, | const TensorLayout& variance, const TensorLayout& batch_mean, | ||||
const TensorLayout& batch_inv_variance, const TensorLayout& dst, | const TensorLayout& batch_inv_variance, const TensorLayout& dst, | ||||
size_t workspace_in_bytes, size_t reserve_in_bytes) { | size_t workspace_in_bytes, size_t reserve_in_bytes) { | ||||
// moving some python assert to dnn to decrease the assert overhead | |||||
megdnn_assert( | |||||
src.ndim == 4, | |||||
"ndim of the input tensor for batch_norm should be 4, but you give %zu", | |||||
src.ndim); | |||||
megdnn_assert(bn_scale.ndim == 4, "expect 4, get %zu\n", bn_scale.ndim); | |||||
megdnn_assert(bn_bias.ndim == 4, "expect 4, get %zu\n", bn_bias.ndim); | |||||
megdnn_assert_eq_layout(bn_scale, bn_bias); | |||||
megdnn_assert_eq_layout(batch_mean, batch_inv_variance); | |||||
megdnn_assert_contiguous(src); | megdnn_assert_contiguous(src); | ||||
megdnn_assert_eq_layout(src, dst); | megdnn_assert_eq_layout(src, dst); | ||||
megdnn_assert_eq_layout(bn_scale, bn_bias); | megdnn_assert_eq_layout(bn_scale, bn_bias); | ||||
@@ -58,16 +58,19 @@ class autocast: | |||||
self._origin_low = None | self._origin_low = None | ||||
def __enter__(self): | def __enter__(self): | ||||
self._origin_enabled, amp._enabled = amp._enabled, self.enabled | |||||
self._origin_high = amp._high_prec_dtype | |||||
amp._high_prec_dtype = self.high_prec_dtype | |||||
self._origin_low = amp._low_prec_dtype | |||||
amp._low_prec_dtype = self.low_prec_dtype | |||||
self._origin_enabled = amp._enabled | |||||
self._origin_high = amp._get_amp_high_prec_dtype() | |||||
self._origin_low = amp._get_amp_low_prec_dtype() | |||||
amp._enabled = self.enabled | |||||
amp._set_amp_dtype_autocast(self.enabled) | |||||
amp._set_amp_high_prec_dtype(self.high_prec_dtype) | |||||
amp._set_amp_low_prec_dtype(self.low_prec_dtype) | |||||
def __exit__(self, *args): | def __exit__(self, *args): | ||||
amp._enabled = self._origin_enabled | amp._enabled = self._origin_enabled | ||||
amp._high_prec_dtype = self._origin_high | |||||
amp._low_prec_dtype = self._origin_low | |||||
amp._set_amp_dtype_autocast(self._origin_enabled) | |||||
amp._set_amp_high_prec_dtype(self._origin_high) | |||||
amp._set_amp_low_prec_dtype(self._origin_low) | |||||
def __call__(self, func): | def __call__(self, func): | ||||
@functools.wraps(func) | @functools.wraps(func) | ||||
@@ -5,9 +5,18 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .._imperative_rt.core2 import ( | |||||
_get_amp_dtype_autocast, | |||||
_get_amp_high_prec_dtype, | |||||
_get_amp_low_prec_dtype, | |||||
_set_amp_dtype_autocast, | |||||
_set_amp_high_prec_dtype, | |||||
_set_amp_low_prec_dtype, | |||||
) | |||||
_enabled = False | _enabled = False | ||||
_high_prec_dtype = "float32" | |||||
_low_prec_dtype = "float16" | |||||
_set_amp_dtype_autocast(_enabled) | |||||
@property | @property | ||||
@@ -28,6 +37,7 @@ def enabled(mod): | |||||
def enabled(mod, enabled: bool): | def enabled(mod, enabled: bool): | ||||
global _enabled | global _enabled | ||||
_enabled = enabled | _enabled = enabled | ||||
_set_amp_dtype_autocast(_enabled) | |||||
@property | @property | ||||
@@ -42,13 +52,12 @@ def high_prec_dtype(mod): | |||||
import megengine as mge | import megengine as mge | ||||
mge.amp.high_prec_dtype = "float32" | mge.amp.high_prec_dtype = "float32" | ||||
""" | """ | ||||
return _high_prec_dtype | |||||
return _get_amp_high_prec_dtype() | |||||
@high_prec_dtype.setter | @high_prec_dtype.setter | ||||
def high_prec_dtype(mod, dtype: str): | def high_prec_dtype(mod, dtype: str): | ||||
global _high_prec_dtype | |||||
_high_prec_dtype = dtype | |||||
_set_amp_high_prec_dtype(dtype) | |||||
@property | @property | ||||
@@ -63,10 +72,9 @@ def low_prec_dtype(mod): | |||||
import megengine as mge | import megengine as mge | ||||
mge.amp.low_prec_dtype = "float16" | mge.amp.low_prec_dtype = "float16" | ||||
""" | """ | ||||
return _low_prec_dtype | |||||
return _get_amp_low_prec_dtype() | |||||
@low_prec_dtype.setter | @low_prec_dtype.setter | ||||
def low_prec_dtype(mod, dtype: str): | def low_prec_dtype(mod, dtype: str): | ||||
global _low_prec_dtype | |||||
_low_prec_dtype = dtype | |||||
_set_amp_low_prec_dtype(dtype) |
@@ -25,7 +25,6 @@ from .utils import ( | |||||
astensor1d, | astensor1d, | ||||
astype, | astype, | ||||
cast_tensors, | cast_tensors, | ||||
convert_inputs, | |||||
make_shape_tuple, | make_shape_tuple, | ||||
subgraph, | subgraph, | ||||
) | ) | ||||
@@ -40,38 +39,6 @@ def _elwise_apply(args, mode): | |||||
def _elwise(*args, mode): | def _elwise(*args, mode): | ||||
args = convert_inputs(*args) | |||||
if ( | |||||
mode | |||||
in ( | |||||
_ElwMod.TRUE_DIV, | |||||
_ElwMod.EXP, | |||||
_ElwMod.POW, | |||||
_ElwMod.LOG, | |||||
_ElwMod.EXPM1, | |||||
_ElwMod.LOG1P, | |||||
_ElwMod.ACOS, | |||||
_ElwMod.ASIN, | |||||
_ElwMod.ATAN2, | |||||
_ElwMod.COS, | |||||
_ElwMod.SIN, | |||||
_ElwMod.LOG_SUM_EXP, | |||||
) | |||||
and ( | |||||
amp._enabled | |||||
or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) | |||||
) | |||||
or mode in (_ElwMod.TANH,) | |||||
and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) | |||||
): | |||||
# autocast to FP32 to maintain precision | |||||
# or to avoid op's not supporting all int args | |||||
args = cast_tensors(*args, promote=True) | |||||
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype( | |||||
args[0].dtype, np.integer | |||||
): | |||||
return args[0] | |||||
return _elwise_apply(args, mode) | return _elwise_apply(args, mode) | ||||
@@ -504,10 +471,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||||
def _reduce(mode): | def _reduce(mode): | ||||
def f(self, axis=None, keepdims: bool = False): | def f(self, axis=None, keepdims: bool = False): | ||||
data = self | data = self | ||||
if mode == "mean": | |||||
data = data.astype("float32") | |||||
elif self.dtype == np.bool_: | |||||
data = data.astype("int32") | |||||
if axis is None: | if axis is None: | ||||
assert not keepdims, "can not set axis=None and keepdims=True" | assert not keepdims, "can not set axis=None and keepdims=True" | ||||
result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) | result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) | ||||
@@ -526,9 +489,6 @@ def _reduce(mode): | |||||
if not keepdims: | if not keepdims: | ||||
result = _remove_axis(result, axis) | result = _remove_axis(result, axis) | ||||
if self.dtype == np.bool_: | |||||
if mode in ["min", "max"]: | |||||
result = result.astype("bool") | |||||
return result | return result | ||||
return f | return f | ||||
@@ -16,6 +16,8 @@ from .._imperative_rt import make_const | |||||
from .._imperative_rt.core2 import ( | from .._imperative_rt.core2 import ( | ||||
SymbolVar, | SymbolVar, | ||||
Tensor, | Tensor, | ||||
_get_convert_inputs, | |||||
_set_convert_inputs, | |||||
apply, | apply, | ||||
dtype_promotion, | dtype_promotion, | ||||
get_device, | get_device, | ||||
@@ -27,15 +29,13 @@ from .._wrap import as_device | |||||
from ..autodiff.grad import Function | from ..autodiff.grad import Function | ||||
from ..ops import builtin | from ..ops import builtin | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
from .amp import _high_prec_dtype, _low_prec_dtype | |||||
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype | |||||
from .dtype import is_dtype_equal, is_quantize | from .dtype import is_dtype_equal, is_quantize | ||||
_enable_convert_inputs = True | |||||
def get_convert_inputs(): | def get_convert_inputs(): | ||||
r"""get the curerent state of `_enable_convert_inputs`""" | r"""get the curerent state of `_enable_convert_inputs`""" | ||||
return _enable_convert_inputs | |||||
return _get_convert_inputs() | |||||
def set_convert_inputs(flag): | def set_convert_inputs(flag): | ||||
@@ -44,10 +44,7 @@ def set_convert_inputs(flag): | |||||
`_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for | `_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for | ||||
internal use only, and should be removed when the tensor-like system is refactored. | internal use only, and should be removed when the tensor-like system is refactored. | ||||
""" | """ | ||||
global _enable_convert_inputs | |||||
backup = _enable_convert_inputs | |||||
_enable_convert_inputs = flag | |||||
return backup | |||||
return _set_convert_inputs(flag) | |||||
def concatenate(inputs, axis=0, *, device=None): | def concatenate(inputs, axis=0, *, device=None): | ||||
@@ -75,7 +72,7 @@ def convert_single_value(v, *, dtype=None, device=None): | |||||
def convert_inputs(*args, device=None): | def convert_inputs(*args, device=None): | ||||
if not _enable_convert_inputs: | |||||
if not _get_convert_inputs(): | |||||
return args | return args | ||||
dtype = dtype_promotion(args) | dtype = dtype_promotion(args) | ||||
@@ -109,9 +106,9 @@ def convert_inputs(*args, device=None): | |||||
def cast_tensors(*args, promote=False): | def cast_tensors(*args, promote=False): | ||||
if promote: | if promote: | ||||
dtype = _high_prec_dtype | |||||
dtype = _get_amp_high_prec_dtype() | |||||
else: | else: | ||||
dtype = _low_prec_dtype | |||||
dtype = _get_amp_low_prec_dtype() | |||||
return tuple(arg.astype(dtype) if arg is not None else None for arg in args) | return tuple(arg.astype(dtype) if arg is not None else None for arg in args) | ||||
@@ -16,6 +16,7 @@ from ..core.tensor.array_method import _elwise | |||||
from ..core.tensor.utils import convert_inputs | from ..core.tensor.utils import convert_inputs | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from ..utils.deprecation import deprecated_func | from ..utils.deprecation import deprecated_func | ||||
from .tensor_cache import get_scalar_one | |||||
__all__ = [ | __all__ = [ | ||||
"abs", | "abs", | ||||
@@ -359,7 +360,11 @@ def asin(x): | |||||
def atan(x): | def atan(x): | ||||
r"""Element-wise `inverse tangent`.""" | r"""Element-wise `inverse tangent`.""" | ||||
return _elwise(x, 1, mode=Elemwise.Mode.ATAN2) | |||||
return _elwise( | |||||
x, | |||||
get_scalar_one("float32", x.device if isinstance(x, Tensor) else None), | |||||
mode=Elemwise.Mode.ATAN2, | |||||
) | |||||
def atan2(y, x): | def atan2(y, x): | ||||
@@ -253,15 +253,6 @@ def conv2d( | |||||
conv_mode.lower() == "cross_correlation" | conv_mode.lower() == "cross_correlation" | ||||
or conv_mode.name == "CROSS_CORRELATION" | or conv_mode.name == "CROSS_CORRELATION" | ||||
) | ) | ||||
if amp._enabled: | |||||
compute_mode = "float32" | |||||
inp, weight, bias = cast_tensors(inp, weight, bias) | |||||
else: | |||||
dtype = dtype_promotion(inp, weight) | |||||
if inp.dtype != dtype: | |||||
inp = inp.astype(dtype) | |||||
if weight.dtype != dtype: | |||||
weight = weight.astype(dtype) | |||||
stride_h, stride_w = expand_hw(stride) | stride_h, stride_w = expand_hw(stride) | ||||
pad_h, pad_w = expand_hw(padding) | pad_h, pad_w = expand_hw(padding) | ||||
@@ -1328,29 +1319,32 @@ def batch_norm( | |||||
inplace: whether to update ``running_mean`` and ``running_var`` | inplace: whether to update ``running_mean`` and ``running_var`` | ||||
inplace or return new tensors. Default: True | inplace or return new tensors. Default: True | ||||
""" | """ | ||||
if inp.ndim != 4: | |||||
raise NotImplementedError("batch_norm for ndim != 4") | |||||
if param_dim == "dim_1c11": | |||||
C = inp.shape[1] | |||||
pshape = (1, C, 1, 1) | |||||
elif param_dim == "dim_111c": | |||||
C = inp.shape[3] | |||||
pshape = (1, 1, 1, C) | |||||
else: | |||||
raise ValueError("Invalid param_dim {}".format(param_dim)) | |||||
def make_full_if_none(x, value): | def make_full_if_none(x, value): | ||||
x_ndim = None if x is None else x.ndim | |||||
# in general case, x will be returned here directly | |||||
if x_ndim is not None and x_ndim != 1: | |||||
return x | |||||
if param_dim == "dim_1c11": | |||||
C = inp.shape[1] | |||||
pshape = (1, C, 1, 1) | |||||
elif param_dim == "dim_111c": | |||||
C = inp.shape[3] | |||||
pshape = (1, 1, 1, C) | |||||
else: | |||||
raise ValueError("Invalid param_dim {}".format(param_dim)) | |||||
if x is None: | if x is None: | ||||
(x,) = Const(value, dtype=inp.dtype, device=inp.device)() | (x,) = Const(value, dtype=inp.dtype, device=inp.device)() | ||||
shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | ||||
(result,) = apply(builtin.Broadcast(), x, shape) | (result,) = apply(builtin.Broadcast(), x, shape) | ||||
return result | return result | ||||
elif x.ndim == 1: | |||||
else: | |||||
assert x_ndim == 1 | |||||
shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | ||||
(result,) = apply(builtin.Reshape(), x, shape) | (result,) = apply(builtin.Reshape(), x, shape) | ||||
return result | return result | ||||
return x | |||||
has_mean = running_mean is not None | has_mean = running_mean is not None | ||||
has_var = running_var is not None | has_var = running_var is not None | ||||
@@ -1359,16 +1353,6 @@ def batch_norm( | |||||
assert has_mean, "running_mean must be provided in inference mode" | assert has_mean, "running_mean must be provided in inference mode" | ||||
assert has_var, "running_var must be provided in inference mode" | assert has_var, "running_var must be provided in inference mode" | ||||
if has_mean and running_mean.ndim != 4: | |||||
raise ValueError | |||||
if has_var and running_var.ndim != 4: | |||||
raise ValueError | |||||
if amp._enabled: | |||||
inp = inp.astype("float16") | |||||
weight, bias, running_mean, running_var = cast_tensors( | |||||
weight, bias, running_mean, running_var, promote=True | |||||
) | |||||
weight = make_full_if_none(weight, 1) | weight = make_full_if_none(weight, 1) | ||||
bias = make_full_if_none(bias, 0) | bias = make_full_if_none(bias, 0) | ||||
@@ -0,0 +1,34 @@ | |||||
from ..core.ops.special import Const | |||||
from ..jit.tracing import is_tracing | |||||
small_tensor_cache = {} | |||||
def _get_scalar_tensor_with_value(value, dtype=None, device=None): | |||||
global small_tensor_cache | |||||
if is_tracing(): | |||||
(ret,) = Const(value, dtype=dtype, device=device)() | |||||
else: | |||||
cache_key = (value, dtype, device) | |||||
if cache_key not in small_tensor_cache: | |||||
(ret,) = Const(value, dtype=dtype, device=device)() | |||||
small_tensor_cache[cache_key] = ret | |||||
else: | |||||
ret = small_tensor_cache[cache_key] | |||||
return ret | |||||
def get_scalar_zero(dtype=None, device=None): | |||||
return _get_scalar_tensor_with_value(0, dtype, device) | |||||
def get_scalar_zero_point_five(dtype=None, device=None): | |||||
return _get_scalar_tensor_with_value(0.5, dtype, device) | |||||
def get_scalar_one(dtype=None, device=None): | |||||
return _get_scalar_tensor_with_value(1, dtype, device) | |||||
def get_scalar_two(dtype=None, device=None): | |||||
return _get_scalar_tensor_with_value(2, dtype, device) |
@@ -15,6 +15,7 @@ | |||||
#include "megbrain/imperative/ops/backward_graph.h" | #include "megbrain/imperative/ops/backward_graph.h" | ||||
#include "megbrain/imperative/ops/utility.h" | #include "megbrain/imperative/ops/utility.h" | ||||
#include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
#include "megbrain/imperative/transformations/dtype_promote.h" | |||||
#include "megbrain/imperative/transformations/eval.h" | #include "megbrain/imperative/transformations/eval.h" | ||||
#include "megbrain/imperative/transformations/lazy.h" | #include "megbrain/imperative/transformations/lazy.h" | ||||
#include "megbrain/imperative/transformations/scalar.h" | #include "megbrain/imperative/transformations/scalar.h" | ||||
@@ -59,16 +60,19 @@ struct SymbolVarContext { | |||||
TransformationContext context; | TransformationContext context; | ||||
std::shared_ptr<SymbolTransformation> symbol_tsf; | std::shared_ptr<SymbolTransformation> symbol_tsf; | ||||
std::shared_ptr<ScalarTransformation> scalar_tsf; | std::shared_ptr<ScalarTransformation> scalar_tsf; | ||||
std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf; | |||||
SymbolVarContext(cg::ComputingGraph* graph) { | SymbolVarContext(cg::ComputingGraph* graph) { | ||||
symbol_tsf = std::make_shared<SymbolTransformation>(graph); | symbol_tsf = std::make_shared<SymbolTransformation>(graph); | ||||
scalar_tsf = std::make_shared<ScalarTransformation>(); | scalar_tsf = std::make_shared<ScalarTransformation>(); | ||||
dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>(); | |||||
Transformation::swap_context(context); | Transformation::swap_context(context); | ||||
} | } | ||||
void init() { | void init() { | ||||
symbol_tsf->register_at(Transformation::top()); | symbol_tsf->register_at(Transformation::top()); | ||||
scalar_tsf->register_at(Transformation::top()); | scalar_tsf->register_at(Transformation::top()); | ||||
dtype_promote_tsf->register_at(Transformation::top()); | |||||
} | } | ||||
ValueRef symvar2val(py::handle py_symbol_var) { | ValueRef symvar2val(py::handle py_symbol_var) { | ||||
@@ -110,6 +114,9 @@ REGISTE_APPLY_FUNC(cpp_astensor1d) | |||||
#undef REGISTE_APPLY_FUNC | #undef REGISTE_APPLY_FUNC | ||||
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs); | |||||
CompNode _get_device(PyObject* const* args, size_t nargs); | |||||
PyObject* py_apply( | PyObject* py_apply( | ||||
PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) { | PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) { | ||||
try { | try { | ||||
@@ -133,19 +140,59 @@ PyObject* py_apply( | |||||
auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); | auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); | ||||
SmallVector<ValueRef, 8> tensors(nargs); | SmallVector<ValueRef, 8> tensors(nargs); | ||||
bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) && | |||||
py::isinstance<PySymbolVar>(py::handle(args[0])); | |||||
if (is_symbol_var) { | |||||
SmallVector<bool, 8> is_symbol_var(nargs, false); | |||||
ComputingGraph* cg = nullptr; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
if ((!TensorWrapper::try_cast(args[i])) && | |||||
py::isinstance<PySymbolVar>(py::handle(args[i]))) { | |||||
is_symbol_var[i] = true; | |||||
ComputingGraph* cur_cg = | |||||
py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph(); | |||||
if (cg == nullptr) { | |||||
cg = cur_cg; | |||||
} else { | |||||
mgb_assert(cg == cur_cg); | |||||
} | |||||
} | |||||
} | |||||
mgb::CompNode target_cn; | |||||
mgb::DType target_dtype; | |||||
auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef { | |||||
if (!target_dtype.valid()) { | |||||
target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs)); | |||||
target_cn = _get_device(args, nargs); | |||||
} | |||||
HostTensorND ht(target_cn); | |||||
ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype); | |||||
if (PyArray_Check(args[i])) { // non scaler | |||||
return imperative::apply( | |||||
CreateTensor(CreateTensor::Const, target_cn, ht.layout()), | |||||
HostStorage::make(ht.storage()))[0]; | |||||
} else { // scaler | |||||
return imperative::apply( | |||||
CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}), | |||||
HostStorage::make(ht.storage()))[0]; | |||||
} | |||||
}; | |||||
if (cg != nullptr) { | |||||
// swap to a special context to reuse scalar handle | // swap to a special context to reuse scalar handle | ||||
SymbolVarContext context( | |||||
py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph()); | |||||
size_t symbol_var_idx = 8; | |||||
SymbolVarContext context(cg); | |||||
context.init(); | context.init(); | ||||
for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
tensors[i] = context.symvar2val(args[i]); | |||||
if (is_symbol_var[i]) { | |||||
symbol_var_idx = i; | |||||
tensors[i] = context.symvar2val(args[i]); | |||||
} else { | |||||
tensors[i] = convert_pyinput_to_tensor(i); | |||||
} | |||||
} | } | ||||
auto outputs = imperative::apply(*op, tensors); | auto outputs = imperative::apply(*op, tensors); | ||||
auto ret = pybind11::tuple(outputs.size()); | auto ret = pybind11::tuple(outputs.size()); | ||||
auto typeobj = py::handle(args[0]).get_type(); | |||||
auto typeobj = py::handle(args[symbol_var_idx]).get_type(); | |||||
for (size_t i = 0; i < outputs.size(); ++i) { | for (size_t i = 0; i < outputs.size(); ++i) { | ||||
ret[i] = context.val2symvar(typeobj, outputs[i]); | ret[i] = context.val2symvar(typeobj, outputs[i]); | ||||
} | } | ||||
@@ -156,13 +203,7 @@ PyObject* py_apply( | |||||
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | ||||
tensors[i] = tw->m_tensor->data(); | tensors[i] = tw->m_tensor->data(); | ||||
} else { | } else { | ||||
PyErr_SetString( | |||||
PyExc_TypeError, | |||||
ssprintf( | |||||
"op %s expect type Tensor as inputs, got %s actually", | |||||
op->make_name().c_str(), Py_TYPE(args[i])->tp_name) | |||||
.c_str()); | |||||
return nullptr; | |||||
tensors[i] = convert_pyinput_to_tensor(i); | |||||
} | } | ||||
} | } | ||||
@@ -616,6 +657,8 @@ void init_tensor(py::module m) { | |||||
std::shared_ptr<Channel>(channel, [](Channel*) {}))); | std::shared_ptr<Channel>(channel, [](Channel*) {}))); | ||||
transformations.register_at<Segment::Scalar>( | transformations.register_at<Segment::Scalar>( | ||||
std::make_shared<ScalarTransformation>()); | std::make_shared<ScalarTransformation>()); | ||||
transformations.register_at<Segment::DTypePromote>( | |||||
std::make_shared<DTypePromoteTransformation>()); | |||||
static py::exception<interpreter::AsyncError> py_async_error( | static py::exception<interpreter::AsyncError> py_async_error( | ||||
m, "AsyncError", PyExc_RuntimeError); | m, "AsyncError", PyExc_RuntimeError); | ||||
@@ -1137,6 +1180,63 @@ void init_tensor(py::module m) { | |||||
m.def("reset_stats", [] { imperative::Stats::reset(); }); | m.def("reset_stats", [] { imperative::Stats::reset(); }); | ||||
m.def("_get_convert_inputs", | |||||
[]() -> bool { return DTypePromoteCfg::convert_input_enabled; }); | |||||
m.def("_set_convert_inputs", [](bool flag) -> bool { | |||||
bool ret = DTypePromoteCfg::convert_input_enabled; | |||||
DTypePromoteCfg::convert_input_enabled = flag; | |||||
return ret; | |||||
}); | |||||
m.def("_get_amp_dtype_autocast", | |||||
[]() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; }); | |||||
m.def("_set_amp_dtype_autocast", [](bool flag) -> bool { | |||||
bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled; | |||||
DTypePromoteCfg::amp_dtype_autocast_enabled = flag; | |||||
return ret; | |||||
}); | |||||
static auto get_amp_prec_dtype = [](bool is_high) -> std::string { | |||||
DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype | |||||
: DTypePromoteCfg::amp_low_prec_dtype; | |||||
mgb_assert(target.category() == DTypeCategory::FLOAT); | |||||
std::string ret = target.name(); | |||||
transform(ret.begin(), ret.end(), ret.begin(), ::tolower); | |||||
return ret; | |||||
}; | |||||
static auto set_amp_prec_dtype = [](bool is_high, | |||||
std::string dtype_name) -> std::string { | |||||
DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype | |||||
: DTypePromoteCfg::amp_low_prec_dtype; | |||||
std::string ret = target.name(); | |||||
if (dtype_name == "float32") { | |||||
target = dtype::Float32(); | |||||
} else if (dtype_name == "float16") { | |||||
target = dtype::Float16(); | |||||
} else if (dtype_name == "bfloat16") { | |||||
target = dtype::BFloat16(); | |||||
} else { | |||||
mgb_assert( | |||||
false, "casted type of amp should be float, but you give %s\n", | |||||
dtype_name.c_str()); | |||||
} | |||||
transform(ret.begin(), ret.end(), ret.begin(), ::tolower); | |||||
return ret; | |||||
}; | |||||
m.def("_get_amp_high_prec_dtype", | |||||
[]() -> std::string { return get_amp_prec_dtype(true); }); | |||||
m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string { | |||||
return set_amp_prec_dtype(true, dtype_name); | |||||
}); | |||||
m.def("_get_amp_low_prec_dtype", | |||||
[]() -> std::string { return get_amp_prec_dtype(false); }); | |||||
m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string { | |||||
return set_amp_prec_dtype(false, dtype_name); | |||||
}); | |||||
py::register_exception<TraceError>(m, "TraceError"); | py::register_exception<TraceError>(m, "TraceError"); | ||||
} | } | ||||
@@ -26,12 +26,13 @@ struct TransformationManager { | |||||
enum Segment { | enum Segment { | ||||
ModuleTrace, | ModuleTrace, | ||||
Grad, | Grad, | ||||
DTypePromote, | |||||
Scalar, | Scalar, | ||||
Trace, | Trace, | ||||
Eval, | Eval, | ||||
}; | }; | ||||
std::array<std::vector<std::shared_ptr<Transformation>>, 5> segments; | |||||
std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments; | |||||
template <Segment segment> | template <Segment segment> | ||||
void register_at(std::shared_ptr<Transformation> transformation) { | void register_at(std::shared_ptr<Transformation> transformation) { | ||||
@@ -14,20 +14,20 @@ def test_grad_scaler(): | |||||
assert amp.enabled == enabled | assert amp.enabled == enabled | ||||
assert origin_amp._enabled == enabled | assert origin_amp._enabled == enabled | ||||
assert amp.low_prec_dtype == low | assert amp.low_prec_dtype == low | ||||
assert origin_amp._low_prec_dtype == low | |||||
assert origin_amp._get_amp_low_prec_dtype() == low | |||||
assert amp.high_prec_dtype == high | assert amp.high_prec_dtype == high | ||||
assert origin_amp._high_prec_dtype == high | |||||
assert origin_amp._get_amp_high_prec_dtype() == high | |||||
origin_enabled = amp.enabled | origin_enabled = amp.enabled | ||||
origin_high = amp.high_prec_dtype | origin_high = amp.high_prec_dtype | ||||
origin_low = amp.low_prec_dtype | origin_low = amp.low_prec_dtype | ||||
with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"): | |||||
check(True, "low", "high") | |||||
with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"): | |||||
check(True, "float16", "float32") | |||||
check(origin_enabled, origin_low, origin_high) | check(origin_enabled, origin_low, origin_high) | ||||
amp.enabled = True | amp.enabled = True | ||||
amp.high_prec_dtype = "high" | |||||
amp.low_prec_dtype = "low" | |||||
check(True, "low", "high") | |||||
amp.high_prec_dtype = "float32" | |||||
amp.low_prec_dtype = "float16" | |||||
check(True, "float16", "float32") | |||||
amp.enabled = origin_enabled | amp.enabled = origin_enabled | ||||
amp.high_prec_dtype = origin_high | amp.high_prec_dtype = origin_high | ||||
amp.low_prec_dtype = origin_low | amp.low_prec_dtype = origin_low | ||||
@@ -0,0 +1,251 @@ | |||||
#include "megbrain/imperative/transformations/dtype_promote.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
namespace mgb::imperative { | |||||
bool DTypePromoteCfg::convert_input_enabled = true; | |||||
bool DTypePromoteCfg::amp_dtype_autocast_enabled = false; | |||||
DType DTypePromoteCfg::amp_high_prec_dtype = dtype::Float32(); | |||||
DType DTypePromoteCfg::amp_low_prec_dtype = dtype::Float16(); | |||||
namespace { | |||||
// TODO: ScalarRule and DTypePromoteRule should be unified | |||||
using DTypePromoteRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>; | |||||
static std::unordered_map<Typeinfo*, DTypePromoteRule> dtype_promotion_rules; | |||||
template <typename T> | |||||
void register_dtype_promote_rule(const DTypePromoteRule& rule) { | |||||
dtype_promotion_rules[T::typeinfo()] = [rule](const OpDef& def, | |||||
Span<ValueRef> inputs) { | |||||
return rule(def.cast_final_safe<T>(), inputs); | |||||
}; | |||||
} | |||||
bool is_quantized_dtype(const DType& dtype) { | |||||
return dtype.category() == DTypeCategory::QUANTIZED; | |||||
} | |||||
bool is_all_integer(const SmallVector<DType>& dtypes) { | |||||
for (size_t i = 0; i < dtypes.size(); ++i) { | |||||
if (dtypes[i].category() != DTypeCategory::INT) { | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
SmallVector<DType> get_value_dtypes(const Span<ValueRef> inputs) { | |||||
SmallVector<DType> dtypes(inputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
dtypes[i] = *(inputs[i].dtype()); | |||||
} | |||||
return dtypes; | |||||
} | |||||
mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) { | |||||
if (dtypes.size() == 0) { | |||||
mgb_assert(false, "there is no input for operator, dtype promote failed"); | |||||
} | |||||
mgb::DType ret = dtypes[0]; | |||||
for (size_t i = 1; i < dtypes.size(); ++i) { | |||||
ret = mgb::dtype_promotion(ret, dtypes[i]); | |||||
} | |||||
return ret; | |||||
} | |||||
ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
auto&& elem_op = op.cast_final_safe<Elemwise>(); | |||||
SmallVector<DType> dtypes(inputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
dtypes[i] = *(inputs[i].dtype()); | |||||
} | |||||
ValueRefList converted(inputs.size()); | |||||
mgb::DType target_dtype = get_promoted_dtype(dtypes); | |||||
// TODO: we can save the dtypes of inputs here and perform TypeCvt at the end of | |||||
// this function, rather than perform TypeCvt eagerly. But for the compatibility, we | |||||
// implement this function with the similar process as the python version and | |||||
// perform TypeCvt here, so we maybe do TypeCvt several times in these function | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype && | |||||
DTypePromoteCfg::convert_input_enabled) { | |||||
converted[i] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||||
dtypes[i] = target_dtype; | |||||
} else { | |||||
converted[i] = inputs[i]; | |||||
} | |||||
} | |||||
static std::unordered_set<Elemwise::Mode> cast_case1 = { | |||||
Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP, | |||||
Elemwise::Mode::POW, Elemwise::Mode::LOG, | |||||
Elemwise::Mode::EXPM1, Elemwise::Mode::LOG1P, | |||||
Elemwise::Mode::ACOS, Elemwise::Mode::ASIN, | |||||
Elemwise::Mode::ATAN2, Elemwise::Mode::COS, | |||||
Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP, | |||||
}; | |||||
static std::unordered_set<Elemwise::Mode> cast_case2 = { | |||||
Elemwise::Mode::TANH, | |||||
}; | |||||
auto cast_to_high_prec = [&]() { | |||||
for (size_t i = 0; i < dtypes.size(); ++i) { | |||||
if (dtypes[i] != DTypePromoteCfg::amp_high_prec_dtype) { | |||||
converted[i] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)), | |||||
converted[i])[0]; | |||||
dtypes[i] = DTypePromoteCfg::amp_high_prec_dtype; | |||||
} | |||||
} | |||||
}; | |||||
if (cast_case1.find(elem_op.mode) != cast_case1.end()) { | |||||
if (DTypePromoteCfg::amp_dtype_autocast_enabled || is_all_integer(dtypes)) { | |||||
cast_to_high_prec(); | |||||
} | |||||
} | |||||
if (cast_case2.find(elem_op.mode) != cast_case2.end()) { | |||||
if (is_all_integer(dtypes)) { | |||||
cast_to_high_prec(); | |||||
} | |||||
} | |||||
static std::unordered_set<Elemwise::Mode> cast_case3 = { | |||||
Elemwise::Mode::CEIL, Elemwise::Mode::FLOOR, Elemwise::Mode::ROUND}; | |||||
if (cast_case3.find(elem_op.mode) != cast_case3.end()) { | |||||
if (dtypes[0].category() == DTypeCategory::INT) { | |||||
return converted; | |||||
} | |||||
} | |||||
return imperative::apply(op, converted); | |||||
} | |||||
ValueRefList reduce_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
auto&& reduce_op = op.cast_final_safe<Reduce>(); | |||||
DType org_dtype = *(inputs[0].dtype()); | |||||
DType target_dtype = org_dtype; | |||||
ValueRefList converted(inputs.begin(), inputs.end()); | |||||
if (reduce_op.mode == Reduce::Mode::MEAN) { | |||||
target_dtype = dtype::Float32(); | |||||
} else if (org_dtype.category() == DTypeCategory::BOOL) { | |||||
target_dtype = dtype::Int32(); | |||||
} | |||||
if (target_dtype != org_dtype) { | |||||
converted[0] = | |||||
imperative::apply(ApplyOp(*TypeCvt::make(target_dtype)), inputs[0])[0]; | |||||
} | |||||
ValueRefList ret = imperative::apply(op, converted); | |||||
if (org_dtype.category() == DTypeCategory::BOOL) { | |||||
if (reduce_op.mode == Reduce::Mode::MIN || | |||||
reduce_op.mode == Reduce::Mode::MAX) { | |||||
ret[0] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(dtype::Bool())), ret[0])[0]; | |||||
} | |||||
} | |||||
return ret; | |||||
} | |||||
ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
auto&& conv_op = const_cast<Convolution&>(op.cast_final_safe<Convolution>()); | |||||
SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||||
mgb::DType target_dtype; | |||||
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||||
conv_op.compute_mode = Convolution::ComputeMode::FLOAT32; | |||||
target_dtype = DTypePromoteCfg::amp_low_prec_dtype; | |||||
} else { | |||||
target_dtype = get_promoted_dtype(dtypes); | |||||
} | |||||
ValueRefList converted(inputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (dtypes[i] != target_dtype) { | |||||
converted[i] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||||
} else { | |||||
converted[i] = inputs[i]; | |||||
} | |||||
} | |||||
return imperative::apply(op, converted); | |||||
} | |||||
ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||||
mgb_assert(inputs.size() > 0); | |||||
ValueRefList converted(inputs.size()); | |||||
converted[0] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0]; | |||||
for (size_t i = 1; i < inputs.size(); ++i) { | |||||
DType idtype = *(inputs[i].dtype()); | |||||
if (idtype != DTypePromoteCfg::amp_high_prec_dtype) { | |||||
converted[i] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)), | |||||
inputs[i])[0]; | |||||
} else { | |||||
converted[i] = inputs[i]; | |||||
} | |||||
} | |||||
return imperative::apply(op, converted); | |||||
} | |||||
return imperative::apply(op, inputs); | |||||
} | |||||
struct DTypePromoteRuleRegistry { | |||||
DTypePromoteRuleRegistry() { | |||||
register_dtype_promote_rule<Elemwise>(elemwise_rule); | |||||
register_dtype_promote_rule<Reduce>(reduce_rule); | |||||
register_dtype_promote_rule<Convolution>(convolution_rule); | |||||
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | |||||
} | |||||
} register_helper; | |||||
} // namespace | |||||
ValueRefList DTypePromoteTransformation::apply_transformation( | |||||
const Operator& op, Span<ValueRef> inputs) { | |||||
if (auto apply_op = op.as<ApplyOp>()) { | |||||
auto iter = dtype_promotion_rules.find(apply_op->op().dyn_typeinfo()); | |||||
if (iter != dtype_promotion_rules.end()) { | |||||
return iter->second(apply_op->op(), inputs); | |||||
} else { | |||||
return imperative::apply(op, inputs); | |||||
} | |||||
} | |||||
return imperative::apply(op, inputs); | |||||
} | |||||
ValueRef DTypePromoteTransformation::unwrap(ValueRef value) { | |||||
return value; | |||||
} | |||||
std::string DTypePromoteTransformation::name() const { | |||||
return "DTypePromoteTransformation"; | |||||
} | |||||
void DTypePromoteTransformation::on_register() { | |||||
// printf("DTypePromoteTransformation has been registered\n"); | |||||
} | |||||
void DTypePromoteTransformation::on_unregister() noexcept { | |||||
// printf("DTypePromoteTransformation has been unregistered\n"); | |||||
} | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,26 @@ | |||||
#pragma once | |||||
#include "megbrain/imperative/dispatch.h" | |||||
#include "megbrain/imperative/value.h" | |||||
namespace mgb::imperative { | |||||
class DTypePromoteTransformation final : public Transformation { | |||||
private: | |||||
public: | |||||
ValueRefList apply_transformation( | |||||
const Operator& op, Span<ValueRef> inputs) override; | |||||
ValueRef unwrap(ValueRef value) override; | |||||
std::string name() const override; | |||||
void on_register() override; | |||||
void on_unregister() noexcept override; | |||||
}; | |||||
struct DTypePromoteCfg { | |||||
static bool convert_input_enabled; | |||||
static bool amp_dtype_autocast_enabled; | |||||
static DType amp_high_prec_dtype; | |||||
static DType amp_low_prec_dtype; | |||||
}; | |||||
} // namespace mgb::imperative |