@@ -924,78 +924,67 @@ bool enable_fastpath(py::handle inp) {
     return true;
 }
 
-py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
-    py::object shape_hdl = _expand_args(args);
-    bool auto_infer = false;
-    py::list lis;
-    py::list new_shape;
-    if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
-        lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
-        for (size_t i = 0; i < lis.size(); ++i) {
-            if (lis[i].is_none()) {
-                auto_infer = true;
-                size_t right = lis.size() - i;
-                py::object tshp = getattr(inp_hdl, "_tuple_shape");
-                if (tshp.is_none()) {
-                    throw py::index_error("does not support `None` with unknown shape");
+py::object _broadcast_cpp(py::handle input, py::handle args) {
+    py::object shape = _expand_args(args);
+    py::list dims;
+    bool all_imm;
+    if (PyList_Check(shape.ptr()) || PyTuple_Check(shape.ptr())) {
+        dims = py::reinterpret_steal<py::list>(PySequence_List(shape.ptr()));
+        mgb_assert(!dims.is_none());
+        all_imm = true;
+        py::object inp_shape = py::none();
+        size_t inp_ndim;
+        for (size_t i = 0; i < dims.size(); ++i) {
+            py::object dim = dims[i];
+            if (dim.is_none()) {
+                ptrdiff_t right = (ptrdiff_t)i - dims.size();
+                if (inp_shape.is_none()) {
+                    inp_shape = input.attr("shape");
+                    mgb_assert(!inp_shape.is_none());
+                    inp_ndim = py::len(inp_shape);
                 }
-                py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
-                if (inp_shape.size() >= right) {
-                    if (enable_fastpath(inp_hdl)) {
-                        lis[i] = inp_shape[inp_shape.size() - right];
-                    }
-                    new_shape.append(inp_shape[inp_shape.size() - right]);
-                } else {
-                    throw py::value_error("invalid broadcast shape");
+                if ((ptrdiff_t)inp_ndim + right < 0) {
+                    throw py::value_error("size cannot be `None` for new axis");
                 }
-            } else {
-                new_shape.append(lis[i]);
-                if (PyLong_Check(lis[i].ptr())) {
-                    int32_t s = lis[i].cast<int32_t>();
-                    if (s < 0) {
-                        throw py::value_error(
-                                "expect shape[" + std::to_string(i) +
-                                "] >= 0 or use `None` to auto infer, got " +
-                                std::to_string(s));
-                    }
+                dim = inp_shape.attr("__getitem__")(right);
+                dims[i] = dim;
+            }
+            if (py::int_::check_(dim)) {
+                if (dim.cast<long>() < 0) {
+                    throw py::value_error(ssprintf(
+                            "expect shape[%zu] >= 0 or use `None` to auto infer, got "
+                            "%s",
+                            i, py::repr(dims[i]).cast<std::string>().c_str()));
+                }
+            } else {
+                all_imm = false;
+            }
+        }
+        shape = dims;
+    } else {
+        all_imm = false;
+    }
-    if (auto_infer) {
-        if (enable_fastpath(inp_hdl)) {
-            shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
-        } else {
-            shape_hdl = _astensor1d_cpp(
-                    new_shape, py::cast((mgb::DType)dtype::Int32()),
-                    getattr(inp_hdl, "device"), inp_hdl);
-        }
+    bool fastpath = all_imm && enable_fastpath(input);
+    if ((!fastpath) && (!is_tensor(shape))) {
+        shape = _astensor1d_cpp(
+                shape, py::cast((mgb::DType)dtype::Int32()), input.attr("device"),
+                input);
+    }
-    py::object shape_tuple;
-    try {
-        shape_tuple = _make_shape_tuple(shape_hdl);
-    } catch (py::error_already_set& err) {
-        shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
-    }
-    auto [shape, fastpath] = tuple2vector(shape_tuple);
-    fastpath &= enable_fastpath(inp_hdl);
     std::shared_ptr<OpDef> op;
-    std::vector<PyObject*> p;
-    py::object shape_tensor;
+    SmallVector<PyObject*> p(2);
     if (fastpath) {
-        op = Broadcast::make(shape);
-        p.resize(2);
+        std::vector<int32_t> shape_vec;
+        for (auto&& dim : dims) {
+            shape_vec.push_back(dim.cast<long>());
+        }
+        op = Broadcast::make(shape_vec);
     } else {
         op = Broadcast::make();
-        shape_tensor = _astensor1d_cpp(
-                shape_hdl, py::cast((mgb::DType)dtype::Int32()),
-                getattr(inp_hdl, "device"), inp_hdl);
-        p.resize(3);
-        p[2] = shape_tensor.ptr();
+        p.push_back(shape.ptr());
     }
-    py::object Op = py::cast(op);
-    p[0] = Op.ptr();
-    p[1] = inp_hdl.ptr();
+    py::object py_op = py::cast(op);
+    p[0] = py_op.ptr();
+    p[1] = input.ptr();
     py::tuple ret =
             py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
     return ret[0];
@@ -1675,4 +1664,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
 }  // namespace mgb::imperative::python
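
A standalone sketch of the `None`-resolution rule the rewritten _broadcast_cpp applies above, restated with plain STL types; the names Dim and resolve_broadcast_shape are illustrative only and do not exist in the codebase. A target dimension given as `None` is filled from the input dimension aligned with it from the right, and `None` on an axis created by the broadcast is rejected.

// Illustrative only: resolve `None` (nullopt) entries of a target broadcast
// shape against the trailing dimensions of the input shape, mirroring the
// `right = i - dims.size()` negative-index logic of the patched loop.
#include <cstdio>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

using Dim = std::optional<long>;  // std::nullopt stands in for Python's None

std::vector<long> resolve_broadcast_shape(
        const std::vector<long>& inp_shape, const std::vector<Dim>& dims) {
    std::vector<long> out(dims.size());
    for (size_t i = 0; i < dims.size(); ++i) {
        if (!dims[i]) {
            // Position counted from the right end of the target shape.
            ptrdiff_t right = (ptrdiff_t)i - (ptrdiff_t)dims.size();
            if ((ptrdiff_t)inp_shape.size() + right < 0) {
                // This axis is newly created by the broadcast, so there is no
                // input dimension to infer it from.
                throw std::invalid_argument("size cannot be `None` for new axis");
            }
            out[i] = inp_shape[(size_t)((ptrdiff_t)inp_shape.size() + right)];
        } else if (*dims[i] < 0) {
            throw std::invalid_argument(
                    "expect shape[" + std::to_string(i) + "] >= 0, got " +
                    std::to_string(*dims[i]));
        } else {
            out[i] = *dims[i];
        }
    }
    return out;
}

int main() {
    // Broadcasting a (2, 3) input to (4, None, None): the two `None`s are
    // filled from the input's trailing dims, giving (4, 2, 3).
    std::vector<long> inp = {2, 3};
    std::vector<Dim> target = {4, std::nullopt, std::nullopt};
    for (long d : resolve_broadcast_shape(inp, target)) std::printf("%ld ", d);
    std::printf("\n");
    return 0;
}

In the patch itself, once every entry of the target shape is an immediate integer and enable_fastpath holds, the shape is passed by value via Broadcast::make(shape_vec); otherwise the shape is converted to an Int32 tensor with _astensor1d_cpp and appended as a third argument to py_apply.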