GitOrigin-RevId: a72f5460b6
tags/v1.9.0
@@ -15,9 +15,15 @@ import numpy as np
 from .. import _config
 from .._imperative_rt.common import CompNode
-from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
+from .._imperative_rt.core2 import (
+    SymbolVar,
+    Tensor,
+    apply,
+    broadcast_cpp,
+    dtype_promotion,
+)
 from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
-from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp
+from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
 from ..ops import builtin
 from . import amp
 from .indexing import getitem, setitem
@@ -331,70 +337,6 @@ def _matmul(
     return result


-def _broadcast(inp, shape):
-    auto_infer = False
-    if isinstance(shape, (list, tuple)):
-        shape_tuple = list(shape)
-        for i, s in enumerate(shape_tuple):
-            if isinstance(s, type(None)):
-                if s is None:
-                    right = i - len(shape_tuple)
-                    inp_shape = inp._tuple_shape
-                    if len(inp_shape) + right >= 0:
-                        shape_tuple[right] = list(inp_shape)[right]
-                        auto_infer = True
-                        continue
-                    else:
-                        raise ValueError("invalided Broadcast shape")
-                else:
-                    raise ValueError(
-                        "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
-                            i, s
-                        )
-                    )
-            if s < 0:
-                raise ValueError(
-                    "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
-                        i, s
-                    )
-                )
-        if auto_infer:
-            shape = tuple(shape_tuple)
-    try:
-        shape_tuple = make_shape_tuple(shape)
-    except ValueError:
-        shape_tuple = shape
-    shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device)
-    (result,) = apply(builtin.Broadcast(), inp, shape)
-    return result
-
-
-def _reshape(x, shape):
-    unspec_axis = None
-    try:
-        shape_tuple = make_shape_tuple(shape)
-    except ValueError:
-        pass
-    else:
-        # XXX: assume unspec_axis is not changed in trace
-        for i, s in enumerate(shape_tuple):
-            if s < 0:
-                if s != -1:
-                    raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
-                if unspec_axis is not None:
-                    raise ValueError(
-                        "multiple -1 in shape: {} & {}".format(unspec_axis, i)
-                    )
-                unspec_axis = i
-    shape = astensor1d(shape, x, dtype="int32", device=x.device)
-    if unspec_axis is None:
-        op = builtin.Reshape()
-    else:
-        op = builtin.Reshape(axis=unspec_axis)
-    (x,) = apply(op, x, shape)
-    return x
-
-
 def _unary_elwise(mode):
     def f(self):
         return _elwise(self, mode=mode)
@@ -667,11 +609,11 @@ class ArrayMethodMixin(abc.ABC):
     def reshape(self, *args):
         r"""See :func:`~.reshape`."""
-        return _reshape(self, _expand_args(args))
+        return reshape_cpp(self, args)

     # FIXME: remove this method
     def _broadcast(self, *args):
-        return _broadcast(self, _expand_args(args))
+        return broadcast_cpp(self, args)

     def transpose(self, *args):
         r"""See :func:`~.transpose`."""
@@ -679,7 +621,7 @@ def flatten(self):
     def flatten(self):
         r"""See :func:`~.flatten`."""
-        return self.reshape(-1)
+        return reshape_cpp(self, (-1,))

     def sum(self, axis=None, keepdims: bool = False):
         r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
@@ -15,6 +15,7 @@ from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import (
     SymbolVar,
     apply,
+    broadcast_cpp,
     dtype_promotion,
     expand_dims_cpp,
     split_cpp,
@@ -24,7 +25,6 @@ from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
-from ..core.tensor.array_method import _broadcast
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -360,7 +360,7 @@ def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
        [[0. 1. 2.]
         [0. 1. 2.]]
    """
-    return _broadcast(inp, shape)
+    return broadcast_cpp(inp, shape)


 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
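
Note (illustrative, not part of the diff): `broadcast_to` keeps accepting `None` entries in the target shape; each `None` is auto-inferred from the input shape, aligned from the right (see the updated `test_broadcast_auto_infer` further down). A minimal sketch, assuming a standard MegEngine install:

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.arange(3, dtype="float32"))  # shape (3,)
y = F.broadcast_to(x, (2, 3))                  # explicit target shape
z = F.broadcast_to(x, (2, None))               # trailing None is filled in as 3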
@@ -135,23 +135,24 @@ std::optional<ValueRefList> elemwise_grad_rule(
 std::optional<ValueRefList> reshape_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
-    mgb_assert(inputs.size() == 2);
+    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
+    size_t nr_inp = inputs.size();
     std::array<ValueRef, 2> input_shapes;
-    for (size_t i = 0; i < 2; ++i) {
+    for (size_t i = 0; i < nr_inp; ++i) {
         if (inputs_require_grad[i]) {
             input_shapes[i] = get_shape(inputs[i]);
         }
     }
     auto maker = CustomGradMaker(backward, inputs.size());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
+    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        SmallVector<ValueRef> ret(2);
+        SmallVector<ValueRef> ret(nr_inp);
         if (!grad) {
             return ret;
         }
-        for (size_t i = 0; i < 2; ++i) {
+        for (size_t i = 0; i < nr_inp; ++i) {
             if (shapes[i]) {
                 ret[i] = reshape_to(grad, shapes[i]);
             }
@@ -162,6 +163,37 @@ std::optional<ValueRefList> reshape_grad_rule(
     return imperative::apply(ApplyOp(op), inputs);
 }

+std::optional<ValueRefList> broadcast_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
+    size_t nr_inp = inputs.size();
+    std::array<ValueRef, 2> input_shapes;
+    for (size_t i = 0; i < nr_inp; ++i) {
+        if (inputs_require_grad[i]) {
+            input_shapes[i] = get_shape(inputs[i]);
+        }
+    }
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(nr_inp);
+        if (!grad) {
+            return ret;
+        }
+        for (size_t i = 0; i < nr_inp; ++i) {
+            if (shapes[i]) {
+                ret[i] = reduce_to(grad, shapes[i]);
+            }
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(ApplyOp(op), inputs);
+}
+
 std::optional<ValueRefList> subtensor_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
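
Note (illustrative, not part of the diff): the new `broadcast_grad_rule` above mirrors `reshape_grad_rule` but calls `reduce_to`, i.e. the incoming gradient is reduced back to the input shape (a sum over the broadcast axes). A quick check in Python, assuming a standard MegEngine install:

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor(np.ones((1, 3), dtype="float32"))
gm = GradManager().attach([x])
with gm:
    y = F.broadcast_to(x, (4, 3))
    gm.backward(y.sum())
print(x.grad.numpy())  # every entry is 4.0: the broadcast axis was summed out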
@@ -330,6 +362,7 @@ struct Init {
     Init() {
         CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
         CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
+        CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule);
         CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
         CustomBackward::register_grad_rule(
                 IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
@@ -637,6 +637,8 @@ WRAP_FUNC_PY35(split_cpp);
 WRAP_FUNC_PY35(expand_dims_cpp);
 WRAP_FUNC_PY35(squeeze_cpp);
 WRAP_FUNC_PY35(transpose_cpp);
+WRAP_FUNC_PY35(broadcast_cpp);
+WRAP_FUNC_PY35(reshape_cpp);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -773,6 +775,8 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
             MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
             MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
+            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
+            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
@@ -800,29 +800,46 @@ size_t fast_ndim(py::handle tensor) {
     return getattr(tensor, "ndim").cast<size_t>();
 }

-py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
+py::object _expand_args(py::handle args) {
+    if (!PyTuple_Check(args.ptr())) {
+        return py::reinterpret_borrow<py::object>(args);
+    }
     py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
+    if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
+                                 is_tensor_or_symbolvar(args_tup[0].ptr()))) {
+        return py::reinterpret_borrow<py::object>(args_tup[0]);
+    } else {
+        return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
+    }
+}
+
+py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
+    py::object obj = _expand_args(args);
+    py::list lis;
+    if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) {
+        lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr()));
+    } else {
+        py::object np = getattr(obj, "numpy")();
+        PyArrayObject* arr = (PyArrayObject*)np.ptr();
+        PyObject* maybe_list = PyArray_ToList(arr);
+        if (PyList_Check(maybe_list)) {
+            lis = py::reinterpret_steal<py::list>(maybe_list);
+        }
+    }
     if (fast_ndim(inp_hdl) == 0) {
-        if (args_tup.size() != 0) {
+        if (lis.size() != 0) {
             throw py::index_error(
                     "transpose for scalar does not accept additional args");
         }
         return getattr(inp_hdl, "to")(getattr(inp_hdl, "device"));
     }
     std::vector<int32_t> pattern;
-    if (!args_tup.size()) {
+    if (!lis.size()) {
         size_t ndim = getattr(inp_hdl, "ndim").cast<size_t>();
         for (size_t i = 0; i < ndim; ++i) {
             pattern.push_back(ndim - i - 1);
         }
     } else {
-        py::list lis;
-        if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
-                                     is_tensor_or_symbolvar(args_tup[0].ptr()))) {
-            lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup[0].ptr()));
-        } else {
-            lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
-        }
         for (size_t i = 0; i < lis.size(); ++i) {
             if (PyLong_Check(lis[i].ptr())) {
                 pattern.push_back(lis[i].cast<int32_t>());
@@ -844,6 +861,182 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
     return ret[0];
 }

+std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
+    std::vector<int32_t> shp;
+    if (!PyTuple_Check(shape.ptr())) {
+        return {shp, false};
+    }
+    py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
+    for (size_t i = 0; i < tup.size(); ++i) {
+        if (!PyLong_Check(tup[i].ptr())) {
+            return {shp, false};
+        } else {
+            shp.push_back(tup[i].cast<int32_t>());
+        }
+    }
+    return {shp, true};
+}
+
+bool enable_fastpath(py::handle inp) {
+    if (!TensorWrapper::try_cast(inp.ptr()) ||
+        TransformationManager::get_instance()
+                        .segments[TransformationManager::Segment::Trace]
+                        .size() > 0 ||
+        TransformationManager::get_instance()
+                        .segments[TransformationManager::Segment::ModuleTrace]
+                        .size() > 0) {
+        return false;
+    }
+    return true;
+}
+
+py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
+    py::object shape_hdl = _expand_args(args);
+    bool auto_infer = false;
+    py::list lis;
+    py::list new_shape;
+    if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
+        lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
+        for (size_t i = 0; i < lis.size(); ++i) {
+            if (lis[i].ptr() == Py_None) {
+                auto_infer = true;
+                size_t right = lis.size() - i;
+                py::object tshp = getattr(inp_hdl, "_tuple_shape");
+                if (tshp.ptr() == Py_None) {
+                    throw py::index_error("does not support `None` with unknown shape");
+                }
+                py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
+                if (inp_shape.size() >= right) {
+                    if (enable_fastpath(inp_hdl)) {
+                        lis[i] = inp_shape[inp_shape.size() - right];
+                    }
+                    new_shape.append(inp_shape[inp_shape.size() - right]);
+                } else {
+                    throw py::value_error("invalid broadcast shape");
+                }
+            } else {
+                new_shape.append(lis[i]);
+                if (PyLong_Check(lis[i].ptr())) {
+                    int32_t s = lis[i].cast<int32_t>();
+                    if (s < 0) {
+                        throw py::value_error(
+                                "expect shape[" + std::to_string(i) +
+                                "] >= 0 or use `None` to auto infer, got " +
+                                std::to_string(s));
+                    }
+                }
+            }
+        }
+    }
+    if (auto_infer) {
+        if (enable_fastpath(inp_hdl)) {
+            shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
+        } else {
+            py::tuple args = py::make_tuple(new_shape, inp_hdl);
+            py::dict kwargs;
+            kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
+            kwargs["device"] = getattr(inp_hdl, "device");
+            shape_hdl = py::reinterpret_steal<py::object>(
+                    PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
+        }
+    }
+    py::object shape_tuple;
+    try {
+        shape_tuple = _make_shape_tuple(shape_hdl);
+    } catch (py::error_already_set& err) {
+        shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
+    }
+    auto [shape, fastpath] = tuple2vector(shape_tuple);
+    fastpath &= enable_fastpath(inp_hdl);
+    std::shared_ptr<OpDef> op;
+    std::vector<PyObject*> p;
+    py::object shape_tensor;
+    if (fastpath) {
+        op = Broadcast::make(shape);
+        p.resize(2);
+    } else {
+        op = Broadcast::make();
+        py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
+        py::dict kwargs;
+        kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
+        kwargs["device"] = getattr(inp_hdl, "device");
+        shape_tensor = py::reinterpret_steal<py::object>(
+                PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
+        p.resize(3);
+        p[2] = shape_tensor.ptr();
+    }
+    py::object Op = py::cast(op);
+    p[0] = Op.ptr();
+    p[1] = inp_hdl.ptr();
+    py::tuple ret =
+            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+    return ret[0];
+}
+
+py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
+    py::object shape_hdl = _expand_args(args);
+    py::object shape_tuple;
+    try {
+        shape_tuple = _make_shape_tuple(shape_hdl);
+    } catch (py::error_already_set& err) {
+        shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
+    }
+    int32_t unspec_axis = -1;
+    if (PyTuple_Check(shape_tuple.ptr())) {
+        py::tuple tup = py::reinterpret_borrow<py::tuple>(shape_tuple);
+        for (size_t i = 0; i < tup.size(); ++i) {
+            py::object obj = py::reinterpret_borrow<py::object>(tup[i]);
+            if (obj < py::int_(0)) {
+                if (obj.not_equal(py::int_(-1))) {
+                    throw py::value_error(
+                            "expect shape [" + std::to_string(i) + "] >= -1, got " +
+                            repr(obj).cast<std::string>());
+                }
+                if (unspec_axis >= 0) {
+                    throw py::value_error(
+                            "multiple -1 in shape: " + std::to_string(unspec_axis) +
+                            " & " + std::to_string(i));
+                }
+                unspec_axis = i;
+            }
+        }
+    }
+    auto [shape, fastpath] = tuple2vector(shape_tuple);
+    fastpath &= enable_fastpath(inp_hdl);
+    std::shared_ptr<OpDef> op;
+    std::vector<PyObject*> p;
+    py::object shape_tensor;
+    if (fastpath) {
+        if (unspec_axis >= 0) {
+            op = Reshape::make(unspec_axis, shape);
+        } else {
+            op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape);
+        }
+        p.resize(2);
+    } else {
+        shape.clear();
+        if (unspec_axis >= 0) {
+            op = Reshape::make(unspec_axis, shape);
+        } else {
+            op = Reshape::make();
+        }
+        py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
+        py::dict kwargs;
+        kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
+        kwargs["device"] = getattr(inp_hdl, "device");
+        shape_tensor = py::reinterpret_steal<py::object>(
+                PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
+        p.resize(3);
+        p[2] = shape_tensor.ptr();
+    }
+    py::object Op = py::cast(op);
+    p[0] = Op.ptr();
+    p[1] = inp_hdl.ptr();
+    py::tuple ret =
+            py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+    return ret[0];
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -900,4 +1093,18 @@ PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }

+PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _broadcast_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
+PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _reshape_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 }  // namespace mgb::imperative::python
@@ -16,4 +16,8 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs);

+PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
+PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
 }  // namespace mgb::imperative::python
@@ -267,7 +267,7 @@ def test_broadcast_auto_infer(is_varnode):
     F.broadcast_to(xx, (None, 1, 2, 3))
     F.broadcast_to(xx, (1, None, 2, 3))
-    t = tensor(2, dtype=np.int32)
+    t = make_tensor(2, network)
     F.broadcast_to(xx, (t, None, 2, 3))
@@ -51,57 +51,75 @@ bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape)
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op = def.cast_final_safe<Broadcast>();
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
     auto&& src = inputs[0];
-    auto&& tshp = inputs[1];
     TensorShape out_shape;
-    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_shape.ndim = 0;
-        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
-    }
-    mgb_assert(
-            tshp.layout.ndim == 1,
-            "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
-            tshp.layout.ndim);
-    size_t target_ndim = tshp.layout.shape[0];
-    out_shape.ndim = target_ndim;
-    auto* ptr = tshp.value.ptr<dt_int32>();
-    for (size_t i = 0; i < target_ndim; ++i) {
-        out_shape[i] = ptr[i];
+    if (nr_inp == 1) {
+        out_shape.ndim = op.shape.size();
+        for (size_t i = 0; i < out_shape.ndim; ++i) {
+            out_shape[i] = op.shape[i];
+        }
+    } else {
+        auto&& tshp = inputs[1];
+        if (tshp.layout.ndim == 0 || tshp.value.empty()) {
+            out_shape.ndim = 0;
+            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
+                    false};
+        }
+        mgb_assert(
+                tshp.layout.ndim == 1,
+                "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
+                tshp.layout.ndim);
+        size_t target_ndim = tshp.layout.shape[0];
+        out_shape.ndim = target_ndim;
+        auto* ptr = tshp.value.ptr<dt_int32>();
+        for (size_t i = 0; i < target_ndim; ++i) {
+            out_shape[i] = ptr[i];
+        }
     }
     mgb_assert(
             valid_broadcast(src.layout, out_shape),
             "the input shape %s can not be broadcasted to target shape %s",
             src.layout.to_string().c_str(), out_shape.to_string().c_str());
     return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }

 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
         SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
-    def.cast_final_safe<Broadcast>();
+    auto&& op = def.cast_final_safe<Broadcast>();
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
+    TensorShape tshp;
     auto&& src = inputs[0];
-    auto&& tshp_nd = inputs[1];
     auto slayout = src->layout();
-    TensorShape tshp;
-    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
+    if (nr_inp == 1) {
+        tshp.ndim = op.shape.size();
+        for (size_t i = 0; i < tshp.ndim; ++i) {
+            tshp[i] = op.shape[i];
+        }
+    } else {
+        auto&& tshp_nd = inputs[1];
+        cg::copy_tensor_value_to_shape(
+                tshp, tshp_nd->get_value().proxy_to_default_cpu());
+    }
     TensorLayout tlayout = slayout.broadcast(tshp);
     // memory forward
     return {Tensor::make(src->blob(), src->offset(), tlayout)};
 }

+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    return layout_checker;
+}
+
 OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
         .make_from_op_node(make_from_op_node)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .apply_on_physical_tensor(apply_on_physical_tensor)
+        .get_input_layout_constraint(get_input_layout_constraint)
         .fallback();
 }  // namespace broadcast
@@ -118,35 +136,49 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& op = def.cast_final_safe<Reshape>();
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
     auto&& src = inputs[0];
-    auto&& tshp = inputs[1];
     TensorShape out_shape;
-    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_shape.ndim = 0;
-        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
-    }
-    mgb_assert(
-            tshp.layout.ndim == 1,
-            "target shape of Reshape expects ndim=1; got ndim=%lu actually",
-            tshp.layout.ndim);
-    if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
-        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
-    }
-    size_t target_ndim = tshp.layout.shape[0];
-    out_shape.ndim = target_ndim;
-    auto* ptr = tshp.value.ptr<dt_int32>();
-    for (size_t i = 0; i < target_ndim; ++i) {
-        out_shape[i] = ptr[i];
-    }
-    if (src.layout.ndim == 0) {
-        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
+    if (nr_inp == 1) {
+        if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
+            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
+                    false};
+        }
+        out_shape.ndim = op.shape.size();
+        for (size_t i = 0; i < out_shape.ndim; ++i) {
+            out_shape[i] = op.shape[i];
+        }
+        if (src.layout.ndim == 0) {
+            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
+                    false};
+        }
+    } else {
+        auto&& tshp = inputs[1];
+        if (tshp.layout.ndim == 0 || tshp.value.empty()) {
+            out_shape.ndim = 0;
+            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
+                    false};
+        }
+        mgb_assert(
+                tshp.layout.ndim == 1,
+                "target shape of Reshape expects ndim=1; got ndim=%lu actually",
+                tshp.layout.ndim);
+        if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
+            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
+                    false};
+        }
+        size_t target_ndim = tshp.layout.shape[0];
+        out_shape.ndim = target_ndim;
+        auto* ptr = tshp.value.ptr<dt_int32>();
+        for (size_t i = 0; i < target_ndim; ++i) {
+            out_shape[i] = ptr[i];
+        }
+        if (src.layout.ndim == 0) {
+            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
+                    false};
+        }
     }
     if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
         mgb_assert(out_shape[op.axis] == -1);
         out_shape[op.axis] = 1;
@@ -167,19 +199,27 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
         SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
-    auto&& op_def = def.cast_final_safe<Reshape>();
+    auto&& op = def.cast_final_safe<Reshape>();
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
     auto&& src = inputs[0];
-    auto&& tshp_nd = inputs[1];
     auto slayout = src->layout();
     TensorShape tshp;
-    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
-    if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
-        mgb_assert(tshp[op_def.axis] == -1);
-        tshp[op_def.axis] = 1;
-        tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
+    if (nr_inp == 1) {
+        tshp.ndim = op.shape.size();
+        for (size_t i = 0; i < tshp.ndim; ++i) {
+            tshp[i] = op.shape[i];
+        }
+    } else {
+        auto&& tshp_nd = inputs[1];
+        cg::copy_tensor_value_to_shape(
+                tshp, tshp_nd->get_value().proxy_to_default_cpu());
+    }
+    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
+        mgb_assert(tshp[op.axis] == -1);
+        tshp[op.axis] = 1;
+        tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
     }
     TensorLayout tlayout;
     mgb_assert(slayout.try_reshape(tlayout, tshp));
@@ -188,17 +228,24 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
-    auto&& op_def = def.cast_final_safe<Reshape>();
+    auto&& op = def.cast_final_safe<Reshape>();
     SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
     layout_checker[0] = [&](const TensorLayout& layout) {
         TensorShape tshp;
         TensorLayout ret;
-        cg::copy_tensor_value_to_shape(
-                tshp, inputs[1]->get_value().proxy_to_default_cpu());
-        if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
-            mgb_assert(tshp[op_def.axis] == -1);
-            tshp[op_def.axis] = 1;
-            tshp[op_def.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
+        if (inputs.size() == 1) {
+            tshp.ndim = op.shape.size();
+            for (size_t i = 0; i < tshp.ndim; ++i) {
+                tshp[i] = op.shape[i];
+            }
+        } else {
+            cg::copy_tensor_value_to_shape(
+                    tshp, inputs[1]->get_value().proxy_to_default_cpu());
+        }
+        if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
+            mgb_assert(tshp[op.axis] == -1);
+            tshp[op.axis] = 1;
+            tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
         }
         if (layout.try_reshape(ret, tshp)) {
             return true;
@@ -243,8 +243,10 @@ ValueRefList get_var_shape_rule(
 ValueRefList reshape_rule(
         const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask,
         const Type<ScalarValue>& scalar_type) {
-    mgb_assert(inputs.size() == 2);
-    bool is_scalar = is_scalar_shape(inputs[1]);
+    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
+    size_t nr_inp = inputs.size();
+    bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
+                     (nr_inp == 1 && reshape.shape.size() == 0);
     if (is_scalar) {
         return {scalar_type.make(imperative::apply(
                 reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
@@ -256,8 +258,10 @@ ValueRefList reshape_rule(
 ValueRefList broadcast_rule(
         const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask,
         const Type<ScalarValue>& scalar_type) {
-    mgb_assert(inputs.size() == 2);
-    bool is_scalar = is_scalar_shape(inputs[1]);
+    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
+    size_t nr_inp = inputs.size();
+    bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
+                     (nr_inp == 1 && broadcast.shape.size() == 0);
     if (is_scalar) {
         return {scalar_type.make(imperative::apply(
                 broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
@@ -250,7 +250,11 @@ def Concat: MgbHashableOp<"Concat", [AxisParam]> {
   );
 }

-def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;
+def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> {
+  let extraArguments = (ins
+    MgbArrayAttr<MgbI32Attr>:$shape
+  );
+}

 def Identity: MgbHashableOp<"Identity">;
@@ -318,7 +322,11 @@ def Dimshuffle: MgbHashableOp<"Dimshuffle"> {
   let results = (outs AnyMemRef);
 }

-def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;
+def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]> {
+  let extraArguments = (ins
+    MgbArrayAttr<MgbI32Attr>:$shape
+  );
+}

 // TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
 def AddAxis: MgbHashableOp<"AddAxis"> {