GitOrigin-RevId: a72f5460b6
tags/v1.9.0
@@ -15,9 +15,15 @@ import numpy as np
from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
from .._imperative_rt.core2 import (
    SymbolVar,
    Tensor,
    apply,
    broadcast_cpp,
    dtype_promotion,
)
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp
from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem

@@ -331,70 +337,6 @@ def _matmul(
    return result

def _broadcast(inp, shape):
    auto_infer = False
    if isinstance(shape, (list, tuple)):
        shape_tuple = list(shape)
        for i, s in enumerate(shape_tuple):
            if isinstance(s, type(None)):
                if s is None:
                    right = i - len(shape_tuple)
                    inp_shape = inp._tuple_shape
                    if len(inp_shape) + right >= 0:
                        shape_tuple[right] = list(inp_shape)[right]
                        auto_infer = True
                        continue
                    else:
                        raise ValueError("invalided Broadcast shape")
                else:
                    raise ValueError(
                        "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
                            i, s
                        )
                    )
            if s < 0:
                raise ValueError(
                    "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
                        i, s
                    )
                )
        if auto_infer:
            shape = tuple(shape_tuple)
    try:
        shape_tuple = make_shape_tuple(shape)
    except ValueError:
        shape_tuple = shape
    shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device)
    (result,) = apply(builtin.Broadcast(), inp, shape)
    return result

def _reshape(x, shape):
    unspec_axis = None
    try:
        shape_tuple = make_shape_tuple(shape)
    except ValueError:
        pass
    else:
        # XXX: assume unspec_axis is not changed in trace
        for i, s in enumerate(shape_tuple):
            if s < 0:
                if s != -1:
                    raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
                if unspec_axis is not None:
                    raise ValueError(
                        "multiple -1 in shape: {} & {}".format(unspec_axis, i)
                    )
                unspec_axis = i
    shape = astensor1d(shape, x, dtype="int32", device=x.device)
    if unspec_axis is None:
        op = builtin.Reshape()
    else:
        op = builtin.Reshape(axis=unspec_axis)
    (x,) = apply(op, x, shape)
    return x

def _unary_elwise(mode):
    def f(self):
        return _elwise(self, mode=mode)

@@ -667,11 +609,11 @@ class ArrayMethodMixin(abc.ABC):
    def reshape(self, *args):
        r"""See :func:`~.reshape`."""
        return _reshape(self, _expand_args(args))
        return reshape_cpp(self, args)

    # FIXME: remove this method
    def _broadcast(self, *args):
        return _broadcast(self, _expand_args(args))
        return broadcast_cpp(self, args)

    def transpose(self, *args):
        r"""See :func:`~.transpose`."""

@@ -679,7 +621,7 @@ class ArrayMethodMixin(abc.ABC):
    def flatten(self):
        r"""See :func:`~.flatten`."""
        return self.reshape(-1)
        return reshape_cpp(self, (-1,))

    def sum(self, axis=None, keepdims: bool = False):
        r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
@@ -15,6 +15,7 @@ from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import (
    SymbolVar,
    apply,
    broadcast_cpp,
    dtype_promotion,
    expand_dims_cpp,
    split_cpp,

@@ -24,7 +25,6 @@ from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor

@@ -360,7 +360,7 @@ def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
        [[0. 1. 2.]
         [0. 1. 2.]]
    """
    return _broadcast(inp, shape)
    return broadcast_cpp(inp, shape)

def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:

@@ -135,23 +135,24 @@ std::optional<ValueRefList> elemwise_grad_rule(
std::optional<ValueRefList> reshape_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 2);
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
    for (size_t i = 0; i < nr_inp; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(2);
        SmallVector<ValueRef> ret(nr_inp);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < 2; ++i) {
        for (size_t i = 0; i < nr_inp; ++i) {
            if (shapes[i]) {
                ret[i] = reshape_to(grad, shapes[i]);
            }

@@ -162,6 +163,37 @@ std::optional<ValueRefList> reshape_grad_rule(
    return imperative::apply(ApplyOp(op), inputs);
}

std::optional<ValueRefList> broadcast_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < nr_inp; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(nr_inp);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < nr_inp; ++i) {
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
            }
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
}
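The new broadcast_grad_rule mirrors reshape_grad_rule, but reduces the incoming gradient back to the source shape (reduce_to), i.e. it sums over the broadcast axes. A minimal sketch of the resulting user-visible behaviour, assuming a MegEngine build that includes this change:

import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.autodiff import GradManager

x = tensor(np.ones((1, 3), dtype="float32"))
gm = GradManager().attach([x])
with gm:
    y = F.broadcast_to(x, (4, 3))  # forward: (1, 3) -> (4, 3)
    gm.backward(y.sum())
print(x.grad.numpy())  # [[4. 4. 4.]] -- gradient reduced back to the (1, 3) source shape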
std::optional<ValueRefList> subtensor_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {

@@ -330,6 +362,7 @@ struct Init {
    Init() {
        CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
        CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
        CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule);
        CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);

@@ -637,6 +637,8 @@ WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
#undef WRAP_FUNC_PY35

#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }

@@ -773,6 +775,8 @@ void init_tensor(py::module m) {
            MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
            MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {

@@ -800,29 +800,46 @@ size_t fast_ndim(py::handle tensor) {
    return getattr(tensor, "ndim").cast<size_t>();
}

py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object _expand_args(py::handle args) {
    if (!PyTuple_Check(args.ptr())) {
        return py::reinterpret_borrow<py::object>(args);
    }
    py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
    if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
                                 is_tensor_or_symbolvar(args_tup[0].ptr()))) {
        return py::reinterpret_borrow<py::object>(args_tup[0]);
    } else {
        return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
    }
}

py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
    py::object obj = _expand_args(args);
    py::list lis;
    if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) {
        lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr()));
    } else {
        py::object np = getattr(obj, "numpy")();
        PyArrayObject* arr = (PyArrayObject*)np.ptr();
        PyObject* maybe_list = PyArray_ToList(arr);
        if (PyList_Check(maybe_list)) {
            lis = py::reinterpret_steal<py::list>(maybe_list);
        }
    }
    if (fast_ndim(inp_hdl) == 0) {
        if (args_tup.size() != 0) {
        if (lis.size() != 0) {
            throw py::index_error(
                    "transpose for scalar does not accept additional args");
        }
        return getattr(inp_hdl, "to")(getattr(inp_hdl, "device"));
    }
    std::vector<int32_t> pattern;
    if (!args_tup.size()) {
    if (!lis.size()) {
        size_t ndim = getattr(inp_hdl, "ndim").cast<size_t>();
        for (size_t i = 0; i < ndim; ++i) {
            pattern.push_back(ndim - i - 1);
        }
    } else {
        py::list lis;
        if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
                                     is_tensor_or_symbolvar(args_tup[0].ptr()))) {
            lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup[0].ptr()));
        } else {
            lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
        }
        for (size_t i = 0; i < lis.size(); ++i) {
            if (PyLong_Check(lis[i].ptr())) {
                pattern.push_back(lis[i].cast<int32_t>());

@@ -844,6 +861,182 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
    return ret[0];
}

std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
    std::vector<int32_t> shp;
    if (!PyTuple_Check(shape.ptr())) {
        return {shp, false};
    }
    py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
    for (size_t i = 0; i < tup.size(); ++i) {
        if (!PyLong_Check(tup[i].ptr())) {
            return {shp, false};
        } else {
            shp.push_back(tup[i].cast<int32_t>());
        }
    }
    return {shp, true};
}

bool enable_fastpath(py::handle inp) {
    if (!TensorWrapper::try_cast(inp.ptr()) ||
        TransformationManager::get_instance()
                        .segments[TransformationManager::Segment::Trace]
                        .size() > 0 ||
        TransformationManager::get_instance()
                        .segments[TransformationManager::Segment::ModuleTrace]
                        .size() > 0) {
        return false;
    }
    return true;
}
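The intent of tuple2vector and enable_fastpath, sketched in Python for illustration (would_take_fastpath is a hypothetical helper, not an actual API): the static-shape form of the op is only used when the target shape is a plain tuple of ints and no trace or module-trace transformation is active; otherwise the shape is still passed as a separate int32 tensor input, as before.

# Hypothetical helper, mirroring the checks in the C++ above.
def would_take_fastpath(shape, tracing=False, module_tracing=False):
    is_static = isinstance(shape, tuple) and all(isinstance(s, int) for s in shape)
    return is_static and not (tracing or module_tracing)

print(would_take_fastpath((4, 3)))                # True: shape stored on the op itself
print(would_take_fastpath((4, 3), tracing=True))  # False: shape passed as a tensor input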
py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
    py::object shape_hdl = _expand_args(args);
    bool auto_infer = false;
    py::list lis;
    py::list new_shape;
    if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
        lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
        for (size_t i = 0; i < lis.size(); ++i) {
            if (lis[i].ptr() == Py_None) {
                auto_infer = true;
                size_t right = lis.size() - i;
                py::object tshp = getattr(inp_hdl, "_tuple_shape");
                if (tshp.ptr() == Py_None) {
                    throw py::index_error("does not support `None` with unknown shape");
                }
                py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
                if (inp_shape.size() >= right) {
                    if (enable_fastpath(inp_hdl)) {
                        lis[i] = inp_shape[inp_shape.size() - right];
                    }
                    new_shape.append(inp_shape[inp_shape.size() - right]);
                } else {
                    throw py::value_error("invalid broadcast shape");
                }
            } else {
                new_shape.append(lis[i]);
                if (PyLong_Check(lis[i].ptr())) {
                    int32_t s = lis[i].cast<int32_t>();
                    if (s < 0) {
                        throw py::value_error(
                                "expect shape[" + std::to_string(i) +
                                "] >= 0 or use `None` to auto infer, got " +
                                std::to_string(s));
                    }
                }
            }
        }
    }
    if (auto_infer) {
        if (enable_fastpath(inp_hdl)) {
            shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
        } else {
            py::tuple args = py::make_tuple(new_shape, inp_hdl);
            py::dict kwargs;
            kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
            kwargs["device"] = getattr(inp_hdl, "device");
            shape_hdl = py::reinterpret_steal<py::object>(
                    PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
        }
    }
    py::object shape_tuple;
    try {
        shape_tuple = _make_shape_tuple(shape_hdl);
    } catch (py::error_already_set& err) {
        shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
    }
    auto [shape, fastpath] = tuple2vector(shape_tuple);
    fastpath &= enable_fastpath(inp_hdl);
    std::shared_ptr<OpDef> op;
    std::vector<PyObject*> p;
    py::object shape_tensor;
    if (fastpath) {
        op = Broadcast::make(shape);
        p.resize(2);
    } else {
        op = Broadcast::make();
        py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
        py::dict kwargs;
        kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
        kwargs["device"] = getattr(inp_hdl, "device");
        shape_tensor = py::reinterpret_steal<py::object>(
                PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
        p.resize(3);
        p[2] = shape_tensor.ptr();
    }
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp_hdl.ptr();
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[0];
}
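_broadcast_cpp preserves the user-visible behaviour of the removed Python _broadcast: None entries in the target shape are auto-inferred from the trailing dimensions of the input, matching the test_broadcast_auto_infer case further below. A small sketch, assuming a MegEngine build with this change:

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(6, dtype="float32").reshape(2, 3))
y = F.broadcast_to(x, (4, None, None))  # None dims are filled in from x: (4, 2, 3)
print(y.shape)                          # (4, 2, 3)
# F.broadcast_to(x, (-1, 2, 3))         # would raise: negative sizes are rejected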
py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
    py::object shape_hdl = _expand_args(args);
    py::object shape_tuple;
    try {
        shape_tuple = _make_shape_tuple(shape_hdl);
    } catch (py::error_already_set& err) {
        shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
    }
    int32_t unspec_axis = -1;
    if (PyTuple_Check(shape_tuple.ptr())) {
        py::tuple tup = py::reinterpret_borrow<py::tuple>(shape_tuple);
        for (size_t i = 0; i < tup.size(); ++i) {
            py::object obj = py::reinterpret_borrow<py::object>(tup[i]);
            if (obj < py::int_(0)) {
                if (obj.not_equal(py::int_(-1))) {
                    throw py::value_error(
                            "expect shape [" + std::to_string(i) + "] >= -1, got " +
                            repr(obj).cast<std::string>());
                }
                if (unspec_axis >= 0) {
                    throw py::value_error(
                            "multiple -1 in shape: " + std::to_string(unspec_axis) +
                            " & " + std::to_string(i));
                }
                unspec_axis = i;
            }
        }
    }
    auto [shape, fastpath] = tuple2vector(shape_tuple);
    fastpath &= enable_fastpath(inp_hdl);
    std::shared_ptr<OpDef> op;
    std::vector<PyObject*> p;
    py::object shape_tensor;
    if (fastpath) {
        if (unspec_axis >= 0) {
            op = Reshape::make(unspec_axis, shape);
        } else {
            op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape);
        }
        p.resize(2);
    } else {
        shape.clear();
        if (unspec_axis >= 0) {
            op = Reshape::make(unspec_axis, shape);
        } else {
            op = Reshape::make();
        }
        py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
        py::dict kwargs;
        kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
        kwargs["device"] = getattr(inp_hdl, "device");
        shape_tensor = py::reinterpret_steal<py::object>(
                PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
        p.resize(3);
        p[2] = shape_tensor.ptr();
    }
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp_hdl.ptr();
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[0];
}
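Likewise for _reshape_cpp, a small sketch assuming this build: exactly one -1 is allowed and marks the axis whose size is inferred; on the fast path it is recorded in the op's axis field alongside the static shape vector.

import numpy as np
from megengine import tensor

x = tensor(np.arange(12, dtype="float32"))
y = x.reshape(3, -1)  # -1 marks the inferred axis -> shape (3, 4)
print(y.shape)        # (3, 4)
# x.reshape(-1, -1)   # would raise: multiple -1 in shape
# x.reshape(3, -2)    # would raise: expect shape [1] >= -1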
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(py::handle(args[0])).release().ptr();

@@ -900,4 +1093,18 @@ PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _broadcast_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _reshape_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

} // namespace mgb::imperative::python

@@ -16,4 +16,8 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);

} // namespace mgb::imperative::python

@@ -267,7 +267,7 @@ def test_broadcast_auto_infer(is_varnode):
    F.broadcast_to(xx, (None, 1, 2, 3))
    F.broadcast_to(xx, (1, None, 2, 3))
    t = tensor(2, dtype=np.int32)
    t = make_tensor(2, network)
    F.broadcast_to(xx, (t, None, 2, 3))

@@ -51,57 +51,75 @@ bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape)
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];
    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
            tshp.layout.ndim);
    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    if (nr_inp == 1) {
        out_shape.ndim = op.shape.size();
        for (size_t i = 0; i < out_shape.ndim; ++i) {
            out_shape[i] = op.shape[i];
        }
    } else {
        auto&& tshp = inputs[1];
        if (tshp.layout.ndim == 0 || tshp.value.empty()) {
            out_shape.ndim = 0;
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        mgb_assert(
                tshp.layout.ndim == 1,
                "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
                tshp.layout.ndim);
        size_t target_ndim = tshp.layout.shape[0];
        out_shape.ndim = target_ndim;
        auto* ptr = tshp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_shape[i] = ptr[i];
        }
    }
    mgb_assert(
            valid_broadcast(src.layout, out_shape),
            "the input shape %s can not be broadcasted to target shape %s",
            src.layout.to_string().c_str(), out_shape.to_string().c_str());
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    def.cast_final_safe<Broadcast>();
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    TensorShape tshp;
    auto&& src = inputs[0];
    auto&& tshp_nd = inputs[1];
    auto slayout = src->layout();
    TensorShape tshp;
    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
    if (nr_inp == 1) {
        tshp.ndim = op.shape.size();
        for (size_t i = 0; i < tshp.ndim; ++i) {
            tshp[i] = op.shape[i];
        }
    } else {
        auto&& tshp_nd = inputs[1];
        cg::copy_tensor_value_to_shape(
                tshp, tshp_nd->get_value().proxy_to_default_cpu());
    }
    TensorLayout tlayout = slayout.broadcast(tshp);
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    return layout_checker;
}

OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
} // namespace broadcast

@@ -118,35 +136,49 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];
    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of Reshape expects ndim=1; got ndim=%lu actually",
            tshp.layout.ndim);
    if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    }
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    if (nr_inp == 1) {
        if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        out_shape.ndim = op.shape.size();
        for (size_t i = 0; i < out_shape.ndim; ++i) {
            out_shape[i] = op.shape[i];
        }
        if (src.layout.ndim == 0) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
    } else {
        auto&& tshp = inputs[1];
        if (tshp.layout.ndim == 0 || tshp.value.empty()) {
            out_shape.ndim = 0;
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        mgb_assert(
                tshp.layout.ndim == 1,
                "target shape of Reshape expects ndim=1; got ndim=%lu actually",
                tshp.layout.ndim);
        if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        size_t target_ndim = tshp.layout.shape[0];
        out_shape.ndim = target_ndim;
        auto* ptr = tshp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_shape[i] = ptr[i];
        }
        if (src.layout.ndim == 0) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
    }
    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(out_shape[op.axis] == -1);
        out_shape[op.axis] = 1;

@@ -167,19 +199,27 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<Reshape>();
    auto&& op = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp_nd = inputs[1];
    auto slayout = src->layout();
    TensorShape tshp;
    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
    if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(tshp[op_def.axis] == -1);
        tshp[op_def.axis] = 1;
        tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
    if (nr_inp == 1) {
        tshp.ndim = op.shape.size();
        for (size_t i = 0; i < tshp.ndim; ++i) {
            tshp[i] = op.shape[i];
        }
    } else {
        auto&& tshp_nd = inputs[1];
        cg::copy_tensor_value_to_shape(
                tshp, tshp_nd->get_value().proxy_to_default_cpu());
    }
    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(tshp[op.axis] == -1);
        tshp[op.axis] = 1;
        tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
    }
    TensorLayout tlayout;
    mgb_assert(slayout.try_reshape(tlayout, tshp));

@@ -188,17 +228,24 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& op_def = def.cast_final_safe<Reshape>();
    auto&& op = def.cast_final_safe<Reshape>();
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = [&](const TensorLayout& layout) {
        TensorShape tshp;
        TensorLayout ret;
        cg::copy_tensor_value_to_shape(
                tshp, inputs[1]->get_value().proxy_to_default_cpu());
        if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
            mgb_assert(tshp[op_def.axis] == -1);
            tshp[op_def.axis] = 1;
            tshp[op_def.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
        if (inputs.size() == 1) {
            tshp.ndim = op.shape.size();
            for (size_t i = 0; i < tshp.ndim; ++i) {
                tshp[i] = op.shape[i];
            }
        } else {
            cg::copy_tensor_value_to_shape(
                    tshp, inputs[1]->get_value().proxy_to_default_cpu());
        }
        if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
            mgb_assert(tshp[op.axis] == -1);
            tshp[op.axis] = 1;
            tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
        }
        if (layout.try_reshape(ret, tshp)) {
            return true;

@@ -243,8 +243,10 @@ ValueRefList get_var_shape_rule(
ValueRefList reshape_rule(
        const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
                     (nr_inp == 1 && reshape.shape.size() == 0);
    if (is_scalar) {
        return {scalar_type.make(imperative::apply(
                reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};

@@ -256,8 +258,10 @@ ValueRefList reshape_rule(
ValueRefList broadcast_rule(
        const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
                     (nr_inp == 1 && broadcast.shape.size() == 0);
    if (is_scalar) {
        return {scalar_type.make(imperative::apply(
                broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};

@@ -250,7 +250,11 @@ def Concat: MgbHashableOp<"Concat", [AxisParam]> {
  );
}

def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$shape
  );
}

def Identity: MgbHashableOp<"Identity">;

@@ -318,7 +322,11 @@ def Dimshuffle: MgbHashableOp<"Dimshuffle"> {
  let results = (outs AnyMemRef);
}

def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;
def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$shape
  );
}
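With the extra shape argument on Broadcast and Reshape, the op objects can carry the target shape themselves (the one-input fast path). A hedged sketch of constructing them directly, assuming the generated Python bindings expose extra arguments as constructor keywords the same way they do ordinary params:

# Illustrative only: relies on the assumption that the generated bindings
# accept the new `shape` extra argument as a keyword.
from megengine.core.ops import builtin

op1 = builtin.Broadcast(shape=[4, 3])         # one-input form: shape lives on the op
op2 = builtin.Reshape(axis=1, shape=[3, -1])  # axis points at the -1 entry in shape
print(op1, op2)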
// TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
def AddAxis: MgbHashableOp<"AddAxis"> { | |||