|
|
@@ -51,6 +51,7 @@ bool is_scalar_shape(ValueRef shape) {
     if (shape.is<ScalarValue>()) {
         return false;
     }
+    // may have performance issue
     auto shape_of_shape = shape.shape();
     if (!shape_of_shape) {
         // assume not scalar
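
For orientation: the helper this hunk annotates can be pieced together from the context lines above plus the expression the final hunk removes from reshape_rule. A reconstructed sketch, not a verbatim quote of the file:

bool is_scalar_shape(ValueRef shape) {
    if (shape.is<ScalarValue>()) {
        return false;
    }
    // may have performance issue
    auto shape_of_shape = shape.shape();
    if (!shape_of_shape) {
        // assume not scalar
        return false;
    }
    // taken from the expression removed in the last hunk below:
    // a shape argument equal to ValueShape{0} denotes a rank-0 (scalar) target
    return *shape_of_shape == ValueShape{0};
}

The new comment presumably flags that ValueRef::shape() may force shape inference on a lazily evaluated value, which is why the cheap ScalarValue test stays first.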
|
|
@@ -211,14 +212,21 @@ std::vector<ValueRef> subtensor_rule(
         const Subtensor& subtensor, Span<ValueRef> inputs) {
     mgb_assert(inputs.size() >= 1);
     auto input = inputs[0];
-    size_t ndim = input.is<ScalarValue>() ? 0 : input.shape()->ndim;
-    for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
-        if (idx) {
-            ndim--;
+    bool is_scalar;
+    mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
+    if (auto shape = input.shape()) {
+        size_t ndim = input.shape()->ndim;
+        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
+            if (idx) {
+                ndim--;
+            }
         }
+        is_scalar = ndim == 0;
+    } else {
+        is_scalar = false;
     }
     auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
-    if (!ndim) {
+    if (is_scalar) {
         return {ScalarValue::make(output)};
     } else {
         return {output};
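
The rewrite makes the scalar decision explicit: the result of a subtensor is scalar only when the input's rank is known and every remaining axis is consumed by an integer index (idx); slices keep their axis, and an uninferable shape now conservatively yields "not scalar" instead of dereferencing input.shape() unconditionally. A standalone sketch of the counting logic with plain std types (Item and result_is_scalar are illustrative names, not the MegEngine API):

#include <cstddef>
#include <cstdio>
#include <vector>

// Mirrors the tuple unpacked as [axis, begin, end, step, idx] in the rule;
// the bools record which components of the item are present.
struct Item {
    int axis;
    bool begin, end, step;  // slice components: the axis survives
    bool idx;               // integer index: the axis is dropped
};

// Scalar iff integer indexing consumes every axis of the input.
bool result_is_scalar(std::size_t input_ndim, const std::vector<Item>& items) {
    std::size_t ndim = input_ndim;
    for (const auto& item : items) {
        if (item.idx) {
            ndim--;
        }
    }
    return ndim == 0;
}

int main() {
    // x[0] on a 2-d tensor: one axis indexed, one left -> not scalar (prints 0)
    std::printf("%d\n", result_is_scalar(2, {{0, false, false, false, true}}));
    // x[0][1] on a 2-d tensor: both axes indexed -> scalar (prints 1)
    std::printf("%d\n", result_is_scalar(2, {{0, false, false, false, true},
                                             {1, false, false, false, true}}));
}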
|
|
@@ -261,8 +269,7 @@ std::vector<ValueRef> fastpath_copy_rule(
 
 std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
     mgb_assert(inputs.size() == 2);
-    bool is_scalar =
-            (!inputs[1].is<ScalarValue>()) && *inputs[1].shape() == ValueShape{0};
+    bool is_scalar = is_scalar_shape(inputs[1]);
     auto unwrapped_input = inputs[0].is<ScalarValue>()
             ? inputs[0].cast<ScalarValue>().value()
             : inputs[0];
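
Besides removing the duplicated expression, the switch to is_scalar_shape changes one edge case: the old inline form dereferences inputs[1].shape() unconditionally, while the helper (see the first hunk) returns false when no shape can be inferred; its ScalarValue early-out preserves the old !inputs[1].is<ScalarValue>() guard. A minimal stand-alone model of that difference using std::optional (Shape and this free function are stand-ins, not the MegEngine types):

#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

using Shape = std::vector<std::size_t>;  // stand-in for ValueShape

bool is_scalar_shape(const std::optional<Shape>& shape_of_shape) {
    if (!shape_of_shape) {
        return false;  // shape not inferable: conservatively assume not scalar
    }
    // the shape argument holds zero entries, i.e. the reshape target is rank-0
    return *shape_of_shape == Shape{0};
}

int main() {
    assert(is_scalar_shape(Shape{0}));       // empty target shape -> scalar
    assert(!is_scalar_shape(Shape{2}));      // two-entry target shape -> rank 2
    assert(!is_scalar_shape(std::nullopt));  // unknown shape -> not scalar
}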
|
|
|