GitOrigin-RevId: 70ddc06eee
release-1.5
@@ -86,16 +86,22 @@ def _broadcast(inp, shape): | |||||
def _reshape(x, shape): | def _reshape(x, shape): | ||||
shape_tuple = _make_shape_tuple(shape) | |||||
unspec_axis = None | unspec_axis = None | ||||
# XXX: assume unspec_axis is not changed in trace | |||||
for i, s in enumerate(shape_tuple): | |||||
if s < 0: | |||||
if s != -1: | |||||
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||||
if unspec_axis is not None: | |||||
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | |||||
unspec_axis = i | |||||
try: | |||||
shape_tuple = _make_shape_tuple(shape) | |||||
except ValueError: | |||||
pass | |||||
else: | |||||
# XXX: assume unspec_axis is not changed in trace | |||||
for i, s in enumerate(shape_tuple): | |||||
if s < 0: | |||||
if s != -1: | |||||
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||||
if unspec_axis is not None: | |||||
raise ValueError( | |||||
"multiple -1 in shape: {} & {}".format(unspec_axis, i) | |||||
) | |||||
unspec_axis = i | |||||
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | ||||
if unspec_axis is None: | if unspec_axis is None: | ||||
op = builtin.Reshape() | op = builtin.Reshape() | ||||
@@ -18,9 +18,9 @@ from .utils import astensor1d, isscalar, make_shape_tuple | |||||
def remove_ellipsis(tensor, tuple_val): | def remove_ellipsis(tensor, tuple_val): | ||||
ndim_sum = tensor.ndim | |||||
cur_sum = 0 | cur_sum = 0 | ||||
pos = -1 | pos = -1 | ||||
has_unkown_ndim_bool_index = False | |||||
for i_idx, i in enumerate(tuple_val): | for i_idx, i in enumerate(tuple_val): | ||||
if i is Ellipsis: | if i is Ellipsis: | ||||
for j in tuple_val[:i_idx:-1]: | for j in tuple_val[:i_idx:-1]: | ||||
@@ -28,10 +28,28 @@ def remove_ellipsis(tensor, tuple_val): | |||||
raise IndexError("only one ellipsis is allowed") | raise IndexError("only one ellipsis is allowed") | ||||
pos = i_idx | pos = i_idx | ||||
else: | else: | ||||
cur_sum += i.ndim if hasattr(i, "ndim") else 1 | |||||
try: | |||||
cur_sum += ( | |||||
i.ndim | |||||
if hasattr(i, "dtype") | |||||
and i.dtype == np.bool_ | |||||
and hasattr(i, "ndim") | |||||
else 1 | |||||
) | |||||
except ValueError: | |||||
has_unkown_ndim_bool_index = True | |||||
if pos == -1: | if pos == -1: | ||||
return tuple_val | return tuple_val | ||||
else: | else: | ||||
if has_unkown_ndim_bool_index: | |||||
raise IndexError( | |||||
"Does not support bool index with unknown shape when using Ellipsis" | |||||
) | |||||
try: | |||||
ndim_sum = tensor.ndim | |||||
except ValueError: | |||||
raise IndexError("Does not support Ellipsis when tensor's ndim is unknown.") | |||||
return ( | return ( | ||||
tuple_val[:pos] | tuple_val[:pos] | ||||
+ (slice(None, None, None),) * (ndim_sum - cur_sum) | + (slice(None, None, None),) * (ndim_sum - cur_sum) | ||||
@@ -41,7 +59,11 @@ def remove_ellipsis(tensor, tuple_val): | |||||
# XXX: assume same results during trace | # XXX: assume same results during trace | ||||
def check_bool_index(tensor, tuple_val): | def check_bool_index(tensor, tuple_val): | ||||
cur_shape = make_shape_tuple(tensor.shape) | |||||
try: | |||||
cur_shape = make_shape_tuple(tensor.shape) | |||||
except ValueError: | |||||
return tensor, tuple_val | |||||
new_tuple_val = [] | new_tuple_val = [] | ||||
offset = 0 | offset = 0 | ||||
tdim = 0 | tdim = 0 | ||||
@@ -92,20 +114,31 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
ndim_indexed_scalar = 0 | ndim_indexed_scalar = 0 | ||||
for i in tuple_val: | for i in tuple_val: | ||||
if not i is Ellipsis: | if not i is Ellipsis: | ||||
ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim | |||||
ndim_indexed += ( | |||||
i.ndim | |||||
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") | |||||
else 1 | |||||
) | |||||
if isscalar(i): | if isscalar(i): | ||||
ndim_indexed_scalar += 1 | ndim_indexed_scalar += 1 | ||||
if ndim_indexed > inp.ndim: | |||||
raise IndexError( | |||||
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||||
inp.ndim, ndim_indexed | |||||
ret_scalar = False | |||||
try: | |||||
ret_scalar = ndim_indexed_scalar == inp.ndim | |||||
except ValueError: | |||||
# inp.ndim is unknown | |||||
pass | |||||
else: | |||||
if ndim_indexed > inp.ndim: | |||||
raise IndexError( | |||||
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||||
inp.ndim, len(tuple_val) | |||||
) | |||||
) | ) | ||||
) | |||||
tuple_val = remove_ellipsis(inp, tuple_val) | tuple_val = remove_ellipsis(inp, tuple_val) | ||||
use_subtensor = True | use_subtensor = True | ||||
inp, tuple_val = check_bool_index(inp, tuple_val) | |||||
if inp.shape is not None: | |||||
inp, tuple_val = check_bool_index(inp, tuple_val) | |||||
new_axes = [] | new_axes = [] | ||||
tensors = [] | tensors = [] | ||||
@@ -186,7 +219,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
items.append(item) | items.append(item) | ||||
if new_axes: | if new_axes: | ||||
raise IndexError("newaxis is not allowed here") | raise IndexError("newaxis is not allowed here") | ||||
return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim | |||||
return inp, tensors, items, use_subtensor, ret_scalar | |||||
def try_condtake(tensor, index): | def try_condtake(tensor, index): | ||||
@@ -249,16 +282,21 @@ def setitem(tensor, index, value): | |||||
op = builtin.IndexingMultiAxisVec(items=items) | op = builtin.IndexingMultiAxisVec(items=items) | ||||
(tmp_result,) = apply(op, tensor, *tensors) | (tmp_result,) = apply(op, tensor, *tensors) | ||||
for i in range(min(len(value.shape), len(tmp_result.shape))): | |||||
if (value.shape[-i - 1] != 1) & ( | |||||
value.shape[-i - 1] != tmp_result.shape[-i - 1] | |||||
): | |||||
raise ValueError( | |||||
"cannot copy tensor with shape {} to subtensor with shape {}".format( | |||||
value.shape, tmp_result.shape | |||||
try: | |||||
value_shape = value._tuple_shape | |||||
tmp_result_shape = tmp_result._tuple_shape | |||||
except ValueError: | |||||
pass | |||||
else: | |||||
for i in range(min(len(value_shape), len(tmp_result_shape))): | |||||
if (value_shape[-i - 1] != 1) & ( | |||||
value_shape[-i - 1] != tmp_result_shape[-i - 1] | |||||
): | |||||
raise ValueError( | |||||
"cannot copy tensor with shape {} to subtensor with shape {}".format( | |||||
value_shape, tmp_result_shape | |||||
) | |||||
) | ) | ||||
) | |||||
value = value._broadcast(tmp_result.shape) | value = value._broadcast(tmp_result.shape) | ||||
if use_subtensor: | if use_subtensor: | ||||
@@ -137,6 +137,13 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
ndim = x.ndim | ndim = x.ndim | ||||
except AttributeError: | except AttributeError: | ||||
pass | pass | ||||
except ValueError: | |||||
if dtype is not None and dtype != x.dtype: | |||||
x = astype(x, dtype) | |||||
if device is not None: | |||||
cn = as_device(device).to_c() | |||||
(x,) = apply(builtin.Copy(comp_node=cn), x) | |||||
return x | |||||
else: | else: | ||||
if ndim != 0 and ndim != 1: | if ndim != 0 and ndim != 1: | ||||
raise ValueError("ndim != 1 or 0, get : %d" % ndim) | raise ValueError("ndim != 1 or 0, get : %d" % ndim) | ||||
@@ -148,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
raise TypeError | raise TypeError | ||||
if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | ||||
x = concatenate(x, device=device) | |||||
x = concatenate(x, device=device) if len(x) > 1 else x[0] | |||||
if dtype is not None: | if dtype is not None: | ||||
x = astype(x, dtype) | x = astype(x, dtype) | ||||
return x | return x | ||||
@@ -849,8 +849,15 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
return list(map(int, axis)) | return list(map(int, axis)) | ||||
axis = get_axes() | axis = get_axes() | ||||
ndim = inp.ndim + len(axis) | |||||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||||
try: | |||||
ndim = inp.ndim + len(axis) | |||||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||||
except ValueError: | |||||
if any([ind < 0 for ind in axis]): | |||||
raise IndexError( | |||||
"Does not support negative index when tensor's ndim is unknown" | |||||
) | |||||
axis = sorted(axis) | |||||
assert axis, "axis could not be empty" | assert axis, "axis could not be empty" | ||||
if inp._isscalar(): | if inp._isscalar(): | ||||
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) | assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) | ||||
@@ -384,6 +384,11 @@ PyObject* TensorWrapper::shape() { | |||||
TensorShape shape; | TensorShape shape; | ||||
if (m_tensor->m_var) { // get shape from m_var | if (m_tensor->m_var) { // get shape from m_var | ||||
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); | auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); | ||||
auto&& type = mgr.get_infer_type(m_tensor->m_var); | |||||
using InferType = cg::static_infer::InferType; | |||||
if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) { | |||||
Py_RETURN_NONE; | |||||
} | |||||
auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var); | auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var); | ||||
if (!tshp) { | if (!tshp) { | ||||
Py_RETURN_NONE; | Py_RETURN_NONE; | ||||
@@ -878,6 +883,24 @@ void init_tensor(py::module m) { | |||||
->static_infer_manager(); | ->static_infer_manager(); | ||||
return mgr.infer_shape_fallible(v->m_node); | return mgr.infer_shape_fallible(v->m_node); | ||||
}) | }) | ||||
.def("numpy", [](PySymbolVar* v){ | |||||
auto&& mgr = v->m_node->owner_graph()->static_infer_manager(); | |||||
auto&& type = mgr.get_infer_type(v->m_node); | |||||
using InferType = cg::static_infer::InferType; | |||||
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { | |||||
throw py::value_error("value invalid!"); | |||||
} | |||||
auto* val = mgr.infer_value_fallible(v->m_node); | |||||
if (!val) { | |||||
throw py::value_error("value invalid!"); | |||||
} | |||||
auto np_val = py::cast(*val).attr("numpy")(); | |||||
if (v->is_scalar) { | |||||
return py::object(py::array(np_val).squeeze()); | |||||
} | |||||
return np_val; | |||||
}) | |||||
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | ||||
.def("_setscalar", | .def("_setscalar", | ||||
[](PySymbolVar* v) { return v->is_scalar = true; }) | [](PySymbolVar* v) { return v->is_scalar = true; }) | ||||