GitOrigin-RevId: 70ddc06eee
release-1.5
@@ -86,16 +86,22 @@ def _broadcast(inp, shape): | |||
def _reshape(x, shape): | |||
shape_tuple = _make_shape_tuple(shape) | |||
unspec_axis = None | |||
# XXX: assume unspec_axis is not changed in trace | |||
for i, s in enumerate(shape_tuple): | |||
if s < 0: | |||
if s != -1: | |||
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||
if unspec_axis is not None: | |||
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | |||
unspec_axis = i | |||
try: | |||
shape_tuple = _make_shape_tuple(shape) | |||
except ValueError: | |||
pass | |||
else: | |||
# XXX: assume unspec_axis is not changed in trace | |||
for i, s in enumerate(shape_tuple): | |||
if s < 0: | |||
if s != -1: | |||
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||
if unspec_axis is not None: | |||
raise ValueError( | |||
"multiple -1 in shape: {} & {}".format(unspec_axis, i) | |||
) | |||
unspec_axis = i | |||
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | |||
if unspec_axis is None: | |||
op = builtin.Reshape() | |||
@@ -18,9 +18,9 @@ from .utils import astensor1d, isscalar, make_shape_tuple | |||
def remove_ellipsis(tensor, tuple_val): | |||
ndim_sum = tensor.ndim | |||
cur_sum = 0 | |||
pos = -1 | |||
has_unkown_ndim_bool_index = False | |||
for i_idx, i in enumerate(tuple_val): | |||
if i is Ellipsis: | |||
for j in tuple_val[:i_idx:-1]: | |||
@@ -28,10 +28,28 @@ def remove_ellipsis(tensor, tuple_val): | |||
raise IndexError("only one ellipsis is allowed") | |||
pos = i_idx | |||
else: | |||
cur_sum += i.ndim if hasattr(i, "ndim") else 1 | |||
try: | |||
cur_sum += ( | |||
i.ndim | |||
if hasattr(i, "dtype") | |||
and i.dtype == np.bool_ | |||
and hasattr(i, "ndim") | |||
else 1 | |||
) | |||
except ValueError: | |||
has_unkown_ndim_bool_index = True | |||
if pos == -1: | |||
return tuple_val | |||
else: | |||
if has_unkown_ndim_bool_index: | |||
raise IndexError( | |||
"Does not support bool index with unknown shape when using Ellipsis" | |||
) | |||
try: | |||
ndim_sum = tensor.ndim | |||
except ValueError: | |||
raise IndexError("Does not support Ellipsis when tensor's ndim is unknown.") | |||
return ( | |||
tuple_val[:pos] | |||
+ (slice(None, None, None),) * (ndim_sum - cur_sum) | |||
@@ -41,7 +59,11 @@ def remove_ellipsis(tensor, tuple_val): | |||
# XXX: assume same results during trace | |||
def check_bool_index(tensor, tuple_val): | |||
cur_shape = make_shape_tuple(tensor.shape) | |||
try: | |||
cur_shape = make_shape_tuple(tensor.shape) | |||
except ValueError: | |||
return tensor, tuple_val | |||
new_tuple_val = [] | |||
offset = 0 | |||
tdim = 0 | |||
@@ -92,20 +114,31 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
ndim_indexed_scalar = 0 | |||
for i in tuple_val: | |||
if not i is Ellipsis: | |||
ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim | |||
ndim_indexed += ( | |||
i.ndim | |||
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") | |||
else 1 | |||
) | |||
if isscalar(i): | |||
ndim_indexed_scalar += 1 | |||
if ndim_indexed > inp.ndim: | |||
raise IndexError( | |||
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||
inp.ndim, ndim_indexed | |||
ret_scalar = False | |||
try: | |||
ret_scalar = ndim_indexed_scalar == inp.ndim | |||
except ValueError: | |||
# inp.ndim is unknown | |||
pass | |||
else: | |||
if ndim_indexed > inp.ndim: | |||
raise IndexError( | |||
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||
inp.ndim, len(tuple_val) | |||
) | |||
) | |||
) | |||
tuple_val = remove_ellipsis(inp, tuple_val) | |||
use_subtensor = True | |||
inp, tuple_val = check_bool_index(inp, tuple_val) | |||
if inp.shape is not None: | |||
inp, tuple_val = check_bool_index(inp, tuple_val) | |||
new_axes = [] | |||
tensors = [] | |||
@@ -186,7 +219,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
items.append(item) | |||
if new_axes: | |||
raise IndexError("newaxis is not allowed here") | |||
return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim | |||
return inp, tensors, items, use_subtensor, ret_scalar | |||
def try_condtake(tensor, index): | |||
@@ -249,16 +282,21 @@ def setitem(tensor, index, value): | |||
op = builtin.IndexingMultiAxisVec(items=items) | |||
(tmp_result,) = apply(op, tensor, *tensors) | |||
for i in range(min(len(value.shape), len(tmp_result.shape))): | |||
if (value.shape[-i - 1] != 1) & ( | |||
value.shape[-i - 1] != tmp_result.shape[-i - 1] | |||
): | |||
raise ValueError( | |||
"cannot copy tensor with shape {} to subtensor with shape {}".format( | |||
value.shape, tmp_result.shape | |||
try: | |||
value_shape = value._tuple_shape | |||
tmp_result_shape = tmp_result._tuple_shape | |||
except ValueError: | |||
pass | |||
else: | |||
for i in range(min(len(value_shape), len(tmp_result_shape))): | |||
if (value_shape[-i - 1] != 1) & ( | |||
value_shape[-i - 1] != tmp_result_shape[-i - 1] | |||
): | |||
raise ValueError( | |||
"cannot copy tensor with shape {} to subtensor with shape {}".format( | |||
value_shape, tmp_result_shape | |||
) | |||
) | |||
) | |||
value = value._broadcast(tmp_result.shape) | |||
if use_subtensor: | |||
@@ -137,6 +137,13 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
ndim = x.ndim | |||
except AttributeError: | |||
pass | |||
except ValueError: | |||
if dtype is not None and dtype != x.dtype: | |||
x = astype(x, dtype) | |||
if device is not None: | |||
cn = as_device(device).to_c() | |||
(x,) = apply(builtin.Copy(comp_node=cn), x) | |||
return x | |||
else: | |||
if ndim != 0 and ndim != 1: | |||
raise ValueError("ndim != 1 or 0, get : %d" % ndim) | |||
@@ -148,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
raise TypeError | |||
if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | |||
x = concatenate(x, device=device) | |||
x = concatenate(x, device=device) if len(x) > 1 else x[0] | |||
if dtype is not None: | |||
x = astype(x, dtype) | |||
return x | |||
@@ -849,8 +849,15 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
return list(map(int, axis)) | |||
axis = get_axes() | |||
ndim = inp.ndim + len(axis) | |||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
try: | |||
ndim = inp.ndim + len(axis) | |||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
except ValueError: | |||
if any([ind < 0 for ind in axis]): | |||
raise IndexError( | |||
"Does not support negative index when tensor's ndim is unknown" | |||
) | |||
axis = sorted(axis) | |||
assert axis, "axis could not be empty" | |||
if inp._isscalar(): | |||
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) | |||
@@ -384,6 +384,11 @@ PyObject* TensorWrapper::shape() { | |||
TensorShape shape; | |||
if (m_tensor->m_var) { // get shape from m_var | |||
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); | |||
auto&& type = mgr.get_infer_type(m_tensor->m_var); | |||
using InferType = cg::static_infer::InferType; | |||
if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) { | |||
Py_RETURN_NONE; | |||
} | |||
auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var); | |||
if (!tshp) { | |||
Py_RETURN_NONE; | |||
@@ -878,6 +883,24 @@ void init_tensor(py::module m) { | |||
->static_infer_manager(); | |||
return mgr.infer_shape_fallible(v->m_node); | |||
}) | |||
.def("numpy", [](PySymbolVar* v){ | |||
auto&& mgr = v->m_node->owner_graph()->static_infer_manager(); | |||
auto&& type = mgr.get_infer_type(v->m_node); | |||
using InferType = cg::static_infer::InferType; | |||
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { | |||
throw py::value_error("value invalid!"); | |||
} | |||
auto* val = mgr.infer_value_fallible(v->m_node); | |||
if (!val) { | |||
throw py::value_error("value invalid!"); | |||
} | |||
auto np_val = py::cast(*val).attr("numpy")(); | |||
if (v->is_scalar) { | |||
return py::object(py::array(np_val).squeeze()); | |||
} | |||
return np_val; | |||
}) | |||
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||
.def("_setscalar", | |||
[](PySymbolVar* v) { return v->is_scalar = true; }) | |||