GitOrigin-RevId: 0fc40ad9c8
tags/v1.3.0
@@ -234,15 +234,10 @@ def setitem(tensor, index, value):
     try_result = try_condtake(tensor, index)
     if len(try_result) == 2:
         index = try_result[1]
-        if index.shape[0] == 0:
-            return tensor
         tensor = tensor.reshape(-1)
     if not isinstance(value, Tensor):
         (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
     tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
-    for v in tensors:
-        if len(v.shape) > 0 and v.shape[0] == 0:
-            return tensor
     if use_subtensor:
         op = builtin.Subtensor(items=items)
     else:
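
The two early-returns deleted above made setitem silently skip any assignment whose index selects zero elements; after this change an empty index flows through the normal Subtensor/SetSubtensor path, which the C++ hunks below teach to handle empty layouts. A minimal sketch of the user-visible behavior, assuming MegEngine's public Tensor API (the shapes are illustrative, not from the commit):

    import numpy as np
    from megengine import Tensor

    x = Tensor(np.arange(6, dtype="float32").reshape(2, 3))
    mask = Tensor(np.zeros((2, 3), dtype="bool"))  # boolean index selecting nothing
    x[mask] = 1.0      # previously short-circuited; now runs SetSubtensor on an empty subtensor
    print(x.numpy())   # contents unchanged either way
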
@@ -250,19 +245,17 @@ def setitem(tensor, index, value):
     (tmp_result,) = apply(op, tensor, *tensors)
-    # XXX: broadcast can always be applied even if shapes are equal
-    if make_shape_tuple(value.shape) != make_shape_tuple(tmp_result.shape):
-        for i in range(min(len(value.shape), len(tmp_result.shape))):
-            if (
-                value.shape[-i - 1] != 1
-                and value.shape[-i - 1] != tmp_result.shape[-i - 1]
-            ):
-                raise ValueError(
-                    "cannot copy tensor with shape {} to subtensor with shape {}".format(
-                        value.shape, tmp_result.shape
-                    )
-                )
-        value = value._broadcast(tmp_result.shape)
+    for i in range(min(len(value.shape), len(tmp_result.shape))):
+        if (value.shape[-i - 1] != 1) & (
+            value.shape[-i - 1] != tmp_result.shape[-i - 1]
+        ):
+            raise ValueError(
+                "cannot copy tensor with shape {} to subtensor with shape {}".format(
+                    value.shape, tmp_result.shape
+                )
+            )
+    value = value._broadcast(tmp_result.shape)
     if use_subtensor:
         op = builtin.SetSubtensor(items=items)
     else:
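
The rewritten check drops the outer make_shape_tuple equality guard; as the deleted XXX comment notes, broadcasting is a no-op when the shapes already match, so value can always be broadcast to the subtensor shape. The per-dimension rule is ordinary right-aligned broadcast compatibility; here is a hypothetical standalone restatement (the function name and the shapes in the calls are mine, not from the codebase):

    def check_broadcastable(value_shape, target_shape):
        # compare trailing dimensions; a size-1 dim may broadcast
        for i in range(min(len(value_shape), len(target_shape))):
            if value_shape[-i - 1] != 1 and value_shape[-i - 1] != target_shape[-i - 1]:
                raise ValueError(
                    "cannot copy tensor with shape {} to subtensor with shape {}".format(
                        value_shape, target_shape
                    )
                )

    check_broadcastable((1, 3), (2, 3))  # ok: size-1 dim broadcasts
    check_broadcastable((4,), (2, 3))    # raises ValueError: 4 vs 3
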
@@ -644,12 +644,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
     v0, index0 = cond_take(mask, x)
     v1, index1 = cond_take(~mask, y)
-    if v0.shape == (0,):
-        out = v1
-    elif v1.shape == (0,):
-        out = v0
-    else:
-        out = concat([v0, v1])
+    out = concat([v0, v1])
     out[index0] = v0
     out[index1] = v1
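
where no longer special-cases all-true and all-false masks: once concat (and the kernels patched below) accept zero-length operands, concatenating an empty piece is harmless. A sketch of the invariant in plain NumPy terms, assuming only that empty concatenation behaves as it does in NumPy:

    import numpy as np

    mask = np.array([True, True])           # ~mask selects nothing
    v0 = np.array([1.0, 2.0])               # stands in for cond_take(mask, x)[0]
    v1 = np.empty((0,), dtype="float32")    # stands in for cond_take(~mask, y)[0]
    out = np.concatenate([v0, v1])          # the empty operand is a no-op
    print(out)                              # [1. 2.]
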
@@ -85,7 +85,8 @@ public:
         var->m_comp_node = dev_tensor.comp_node();
         var->m_shape = dev_tensor.shape();
         var->m_dev_tensor = dev_tensor;
-        var->reset_dev_tensor_from_tensor(dev_tensor);
+        var->m_mem_plan.reset_from_owner_var().chunk()
+                .mem_alloc_status.set_from_owner_var();
         return var;
     }
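
The input placeholder now performs the memory-plan bookkeeping inline instead of going through reset_dev_tensor_from_tensor, presumably so an empty device tensor no longer trips that helper. The trigger is simply using an empty tensor as an op input; a hedged illustration via the public API (shapes are illustrative):

    import numpy as np
    from megengine import Tensor

    empty = Tensor(np.empty((0, 3), dtype="float32"))
    y = empty * 2      # the empty input becomes an InputPlaceholder var
    print(y.shape)     # (0, 3)
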
@@ -560,7 +561,11 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) {
         mgb_assert(var->comp_node() == tensor->comp_node() &&
                    var->shape().eq_shape(layout) &&
                    var->dtype() == layout.dtype);
-        var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
+        if (!tensor->layout().is_empty()) {
+            var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
+        } else {
+            var->m_dev_tensor.storage({var->comp_node()});
+        }
         ++ j;
     }
     chk.mem_alloc_status.set_from_owner_var();
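
The output side gets the same treatment: a var whose layout is empty cannot be handed a real device tensor, so it only receives storage bound to the correct comp node. One way to reach this path from Python (a sketch; the all-false mask is illustrative) is cond_take, whose outputs are legitimately zero-length:

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    x = Tensor(np.arange(4, dtype="float32"))
    mask = Tensor(np.zeros((4,), dtype="bool"))
    vals, idxs = F.cond_take(mask, x)  # both outputs have shape (0,)
    print(vals.shape, idxs.shape)
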
@@ -365,6 +365,9 @@ WARN(IndexingIncrMultiAxisVec);
 template <class Opr>
 void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
+    if (output(0)->layout().is_empty()) {
+        return;
+    }
     auto inp = input(0)->dev_tensor();
     inp = inp.sub(fancy_indexing_make_sub_spec(inp.layout()));
     auto &&index_desc = make_megdnn_index_desc(
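
The kernel launch is skipped outright when the indexing output is empty, since there is nothing to gather. An illustrative trigger, assuming the public indexing API dispatches a zero-length integer index to IndexingMultiAxisVec:

    import numpy as np
    from megengine import Tensor

    x = Tensor(np.arange(12, dtype="float32").reshape(3, 4))
    idx = Tensor(np.empty((0,), dtype="int32"))
    y = x[idx]         # empty index vector -> empty output, kernel skipped
    print(y.shape)     # (0, 4)
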
@@ -81,6 +81,11 @@ void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) {
 void ReadonlyFwdHelper::mixin_rofwd_execute(OperatorNodeBase &opr) {
     mgb_assert(m_rofwd_subspec.layout().ndim, "rofwd uninitialized");
+    if (m_rofwd_subspec.layout().is_empty()) {
+        mgb_assert(opr.output(0)->shape().is_empty(), "output layout mismatch");
+        return;
+    }
     auto &&out = opr.output(0)->dev_tensor(),
          &&inp = opr.input(0)->dev_tensor();
     if (m_mem_fwd_success) {
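
Likewise for readonly-forwarding ops such as Subtensor: an empty sub-layout means there is nothing to forward or copy, and the guard additionally asserts that the output shape agrees. An illustrative trigger via an empty slice (shapes are mine, not from the commit):

    import numpy as np
    from megengine import Tensor

    x = Tensor(np.arange(5, dtype="float32"))
    y = x[2:2]         # empty subspec; readonly forward short-circuits
    print(y.shape)     # (0,)
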