Browse Source

feat(mge/opr): add meshgrid opr

GitOrigin-RevId: 6f703295be
release-1.11.1
Megvii Engine Team 2 years ago
parent
commit
d3b2b51918
8 changed files with 320 additions and 11 deletions
  1. +48
    -1
      imperative/python/megengine/functional/tensor.py
  2. +119
    -5
      imperative/src/impl/ops/broadcast.cpp
  3. +5
    -5
      imperative/tablegen/generated/hash.txt
  4. +37
    -0
      imperative/tablegen/generated/opdef.cpp.inl
  5. +90
    -0
      imperative/tablegen/generated/opdef.cpy.inl
  6. +9
    -0
      imperative/tablegen/generated/opdef.h.inl
  7. +7
    -0
      imperative/tablegen/generated/opdef.py.inl
  8. +5
    -0
      src/core/include/megbrain/ir/ops.td

+ 48
- 1
imperative/python/megengine/functional/tensor.py View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from functools import lru_cache
from typing import Iterable, Optional, Sequence, Tuple, Union
from typing import Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np

@@ -36,6 +36,7 @@ __all__ = [
"full_like",
"gather",
"linspace",
"meshgrid",
"ones",
"ones_like",
"repeat",
@@ -1205,3 +1206,49 @@ def cumsum(inp: Tensor, axis: int):
assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
return apply(op, inp)[0]


def meshgrid(*inputs: Tensor, indexing: str = "xy") -> List[Tensor]:
r"""Returns coordinate matrices from coordinate vectors.

Args:
inputs: an arbitrary number of one-dimensional tensors representing grid
coordinates. Each input should have the same numeric data type.
indexing: Cartesian ``'xy'`` or matrix ``'ij'`` indexing of output.
If provided zero or one one-dimensional vector(s) (i.e., the zero- and one-dimensional
cases, respectively), the indexing keyword has no effect and should be ignored.


Returns:
out: list of N tensors, where N is the number of provided one-dimensional input tensors.
Each returned tensor must have rank N. For N one-dimensional tensors having lengths ``Ni = len(xi)``,
* if matrix indexing ``ij``, then each returned tensor must have the shape ``(N1, N2, N3, ..., Nn)``.
* if Cartesian indexing ``xy``, then each returned tensor must have shape ``(N2, N1, N3, ..., Nn)``.
Accordingly, for the two-dimensional case with input one-dimensional tensors of length ``M`` and ``N``,
if matrix indexing ``ij``, then each returned tensor must have shape ``(M, N)``, and, if Cartesian indexing ``xy``,
then each returned tensor must have shape ``(N, M)``.

Similarly, for the three-dimensional case with input one-dimensional tensor of length ``M``, ``N``, and ``P``,
if matrix indexing ``ij``, then each returned tensor must have shape ``(M, N, P)``, and, if Cartesian indexing ``xy``,
then each returned tensor must have shape ``(N, M, P)``.

Each returned tensor should have the same data type as the input tensors.
Examples:
>>> nx, ny = (3, 2)
>>> x = F.linspace(0, 1, nx)
>>> y = F.linspace(0, 1, ny)
>>> xv, yv = F.meshgrid(x, y)
>>> xv
Tensor([[0. 0.5 1. ]
[0. 0.5 1. ]], device=xpux:0)
>>> yv
Tensor([[0. 0. 0.]
[1. 1. 1.]], device=xpux:0)


"""
op = builtin.MeshGrid(indexing)
return apply(op, *inputs)

+ 119
- 5
imperative/src/impl/ops/broadcast.cpp View File

@@ -1,13 +1,129 @@
#include <numeric>
#include "megbrain/graph/helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"

#include "megbrain/graph/helper.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {
namespace meshgrid {
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
return SmallVector<VarNode::LayoutConstraintCallback>(inputs.size());
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
for (size_t i = 0; i < inputs.size() - 1; i++) {
mgb_assert(inputs[i].layout.dtype == inputs[i + 1].layout.dtype);
mgb_assert(inputs[i].comp_node == inputs[i + 1].comp_node);
}
auto&& op = def.cast_final_safe<MeshGrid>();
mgb_assert(op.indexing == "xy" || op.indexing == "ij");
bool success = true;
SmallVector<size_t> shp;
for (size_t i = 0; i < inputs.size(); i++) {
mgb_assert(inputs[i].layout.ndim <= 1);
if (inputs[i].layout.ndim == 0) {
success = false;
}
shp.push_back(inputs[i].layout.total_nr_elems());
}
if (op.indexing == "xy" and shp.size() >= 2) {
std::swap(shp[0], shp[1]);
}
TensorShape tshp(shp);
SmallVector<LogicalTensorDesc> descs;

for (size_t i = 0; i < inputs.size(); i++) {
if (success) {
descs.push_back(
{TensorLayout(tshp, inputs[0].layout.dtype), inputs[0].comp_node});
} else {
descs.push_back(
{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node});
}
}
return {descs, success};
}
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<MeshGrid>();
std::vector<size_t> indexs(inputs.size());
std::iota(indexs.begin(), indexs.end(), 0);
auto cn = inputs[0]->comp_node();
auto graph = inputs[0]->owner_graph();
if (op.indexing == "xy") {
if (indexs.size() >= 2) {
std::swap(indexs[0], indexs[1]);
}
} else {
mgb_assert(op.indexing == "ij", "meshgrid only support \"ij\" or \"xy\"");
}
VarNodeArray shps;
for (size_t ind = 0; ind < inputs.size(); ind++) {
auto&& inp = inputs[indexs[ind]];
shps.push_back(opr::GetVarShape::make(inp).node());
}
VarNode* tshp = opr::Concat::make(shps, 0, cn).node();
VarNodeArray results;
auto t_ndim = inputs.size();
for (size_t ind = 0; ind < inputs.size(); ind++) {
auto axis = indexs[ind];
HostTensorND hv = HostTensorND(cn, {t_ndim}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
std::fill_n(ptr, t_ndim, 1);
ptr[axis] = -1;
auto shp = opr::ImmutableTensor::make(*graph, hv, cn).node();
auto tmp = opr::Reshape::make(inputs[ind], shp, axis).node();
results.push_back(opr::Broadcast::make(tmp, tshp).node());
}
return results;
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<MeshGrid>();
TensorShape tshp;
TensorShape view_shp;
tshp.ndim = inputs.size();
view_shp.ndim = inputs.size();
std::vector<size_t> indexs(inputs.size());
std::iota(indexs.begin(), indexs.end(), 0);

if (op.indexing == "xy") {
if (indexs.size() >= 2) {
std::swap(indexs[0], indexs[1]);
}
} else {
mgb_assert(op.indexing == "ij", "meshgrid only support \"ij\" or \"xy\"");
}
for (size_t ind = 0; ind < inputs.size(); ind++) {
auto&& inp = inputs[indexs[ind]];
mgb_assert(inp->layout().ndim <= 1);
tshp[ind] = inp->layout().total_nr_elems();
view_shp[ind] = 1;
}
SmallVector<TensorPtr> grids;
for (size_t i = 0; i < inputs.size(); i++) {
auto&& src = inputs[i];
TensorLayout layout;
view_shp[indexs[i]] = src->layout().total_nr_elems();
mgb_assert(src->layout().try_reshape(layout, view_shp));
layout = layout.broadcast(tshp);
view_shp[indexs[i]] = 1;
grids.push_back(Tensor::make(src->blob(), src->offset(), layout));
}
return grids;
}
OP_TRAIT_REG(MeshGrid, MeshGrid)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace meshgrid
namespace broadcast {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
@@ -211,7 +327,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
tshp, tshp_nd->get_value().proxy_to_default_cpu());
}
if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op.axis] == -1);
tshp[op.axis] = 1;
tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
@@ -237,7 +352,6 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
tshp, inputs[1]->get_value().proxy_to_default_cpu());
}
if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op.axis] == -1);
tshp[op.axis] = 1;
tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
}
@@ -250,7 +364,7 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
return layout_checker;
}

OP_TRAIT_REG(Reshape, Reshape)
OP_TRAIT_REG(Reshape, Reshape, opr::Reshape)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)


+ 5
- 5
imperative/tablegen/generated/hash.txt View File

@@ -1,7 +1,7 @@
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
e35e13523f43b7bea4034a0bf75937b7 ../../src/core/include/megbrain/ir/ops.td
240dccd6f8d42cadfd08c6ca90fe61b1 generated/opdef.h.inl
a79a4058ff18ffd9593ee5db3deef6c4 generated/opdef.cpp.inl
83c179ee7416824fbfab978a097cd4d3 generated/opdef.py.inl
86f70b1052331130f5e4c0ca53e68423 generated/opdef.cpy.inl
40708c56b1f05fdb7d06cc097a300330 ../../src/core/include/megbrain/ir/ops.td
9f3af118c7fe8d0c9db433825d5ad77b generated/opdef.h.inl
4041e44a8ba3cca3b3affa1ed9ed44a2 generated/opdef.cpp.inl
319e1d170c989fe793a4e9c45decefc4 generated/opdef.py.inl
26a18a7593566128ecce76e8f74dcc5d generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h

+ 37
- 0
imperative/tablegen/generated/opdef.cpp.inl View File

@@ -4672,6 +4672,43 @@ OP_TRAIT_REG(MatrixMul, MatrixMul)
.props(MatrixMul_props_impl)
.make_name(MatrixMul_make_name_impl);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshGrid);

namespace {
size_t MeshGrid_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MeshGrid>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.indexing));
return val;
}
bool MeshGrid_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<MeshGrid>(),
&&b_ = rhs_.cast_final_safe<MeshGrid>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.indexing != b_.indexing) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> MeshGrid_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MeshGrid>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("indexing", op_.indexing);
return props_;
}
std::string MeshGrid_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MeshGrid>();
static_cast<void>(op_);
return "MeshGrid";
}
} // anonymous namespace
OP_TRAIT_REG(MeshGrid, MeshGrid)
.hash(MeshGrid_hash_impl)
.is_same_st(MeshGrid_is_same_st_impl)
.props(MeshGrid_props_impl)
.make_name(MeshGrid_make_name_impl);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshIndexing);

namespace {


+ 90
- 0
imperative/tablegen/generated/opdef.cpy.inl View File

@@ -12467,6 +12467,95 @@ void _init_py_MatrixMul(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MatrixMul::typeinfo(), &py_type).second);
}

PyOpDefBegin(MeshGrid) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"indexing", serialization<decltype(opdef.indexing)>::dump(opdef.indexing)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("indexing");
if (iter != state.end()) {
opdef.indexing = serialization<decltype(opdef.indexing)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(MeshGrid)

int PyOp(MeshGrid)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"indexing", "scope", NULL};
PyObject *indexing = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &indexing, &scope))
return -1;

if (indexing) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MeshGrid)*>(self)->inst().indexing =
py::cast<decltype(MeshGrid::indexing)>(py::handle(indexing));
} CATCH_ALL(-1)
}

if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}

return 0;
}

PyGetSetDef PyOp(MeshGrid)::py_getsetters[] = {
{const_cast<char*>("indexing"), py_get_generic(MeshGrid, indexing), py_set_generic(MeshGrid, indexing), const_cast<char*>("indexing"), NULL},
{NULL} /* Sentinel */
};

PyMethodDef PyOp(MeshGrid)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(MeshGrid)::getstate, METH_NOARGS, "MeshGrid getstate"},
{const_cast<char*>("__setstate__"), PyOp(MeshGrid)::setstate, METH_VARARGS, "MeshGrid setstate"},
{NULL} /* Sentinel */
};
void _init_py_MeshGrid(py::module m) {
using py_op = PyOp(MeshGrid);
auto& py_type = PyOpType(MeshGrid);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.MeshGrid";
py_type.tp_basicsize = sizeof(PyOp(MeshGrid));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "MeshGrid";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("MeshGrid", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshGrid::typeinfo(), &py_type).second);
}

PyOpDefBegin(MeshIndexing) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
@@ -18594,6 +18683,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_MagicMindRuntime(m); \
_init_py_MatrixInverse(m); \
_init_py_MatrixMul(m); \
_init_py_MeshGrid(m); \
_init_py_MeshIndexing(m); \
_init_py_NMSKeep(m); \
_init_py_NvOf(m); \


+ 9
- 0
imperative/tablegen/generated/opdef.h.inl View File

@@ -1262,6 +1262,15 @@ public:
}
};

class MeshGrid : public OpDefImplBase<MeshGrid> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
std::string indexing;
MeshGrid() = default;
MeshGrid(std::string indexing_, std::string scope_ = {}): indexing(indexing_) { set_scope(scope_); }
};

class MeshIndexing : public OpDefImplBase<MeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;



+ 7
- 0
imperative/tablegen/generated/opdef.py.inl View File

@@ -1365,6 +1365,13 @@ MatrixMulInst
.def_readwrite("dimA", &MatrixMul::dimA)
.def_readwrite("dimB", &MatrixMul::dimB);

py::class_<MeshGrid, std::shared_ptr<MeshGrid>, OpDef> MeshGridInst(m, "MeshGrid");

MeshGridInst
.def(py::init<std::string, std::string>(), py::arg("indexing"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("indexing", &MeshGrid::indexing);

py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing");

MeshIndexingInst


+ 5
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -515,4 +515,9 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> {
let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}];

}
def MeshGrid: MgbHashableOp<"MeshGrid"> {
let extraArguments = (ins
MgbStringAttr:$indexing
);
}
#endif // MGB_OPS

Loading…
Cancel
Save