
perf(mge): move Const into C++

GitOrigin-RevId: 31a443cffd
tags/v1.9.0
Megvii Engine Team 3 years ago
parent commit d98be08030
10 changed files with 71 additions and 64 deletions
1. imperative/python/megengine/core/ops/special.py (+0, -40)
2. imperative/python/megengine/core/tensor/utils.py (+5, -5)
3. imperative/python/megengine/functional/math.py (+2, -3)
4. imperative/python/megengine/functional/nn.py (+3, -4)
5. imperative/python/megengine/functional/tensor.py (+3, -3)
6. imperative/python/megengine/functional/tensor_cache.py (+3, -3)
7. imperative/python/megengine/traced_module/expr.py (+2, -2)
8. imperative/python/src/tensor.cpp (+2, -0)
9. imperative/python/src/tensor_utils.cpp (+49, -4)
10. imperative/python/src/tensor_utils.h (+2, -0)

imperative/python/megengine/core/ops/special.py (+0, -40)

@@ -1,40 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import numpy as np
-
-from .._imperative_rt import make_const
-from .._imperative_rt.core2 import SymbolVar, Tensor
-
-
-class Const:
-    def __init__(self, value=None, *, dtype=None, device=None):
-        self.value = np.asarray(value, dtype=dtype)
-        self.dtype = dtype
-        self.device = device
-
-    def __call__(self, *reference):
-        from ...tensor import Tensor
-
-        device = self.device
-
-        if len(reference) != 0:
-            reference = reference[0]
-            assert isinstance(
-                reference, (SymbolVar, Tensor)
-            ), "Reference should be Tensor or VarNode"
-
-            if device is None:
-                device = reference.device
-
-            if isinstance(reference, SymbolVar):
-                cls = type(reference)
-                rst = cls(make_const(reference.graph, self.value, device, self.dtype))
-                return (rst,)
-
-        return (Tensor(self.value, self.dtype, self.device, True),)

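Note (not part of the diff): a minimal before/after sketch of the calling convention this commit changes, assuming a MegEngine build that contains it; the value 1.0 and the device string "xpux" are arbitrary examples.

# Old path (removed above): Const was a Python class whose call returned a 1-tuple.
#   from megengine.core.ops.special import Const
#   (x,) = Const(1.0, dtype="float32", device="xpux")()

# New path: Const is a C++ function exposed via core2, taking positional
# (value, dtype, device, reference) and returning the tensor directly.
from megengine.core._imperative_rt.core2 import Const

x = Const(1.0, "float32", "xpux", None)
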
imperative/python/megengine/core/tensor/utils.py (+5, -5)

@@ -14,6 +14,7 @@ import numpy as np

 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import (
+    Const,
     SymbolVar,
     Tensor,
     _get_convert_inputs,
@@ -28,7 +29,6 @@ from .._imperative_rt.ops import jit_supported
 from .._wrap import as_device
 from ..autodiff.grad import Function
 from ..ops import builtin
-from ..ops.special import Const
 from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize

@@ -67,7 +67,7 @@ def convert_single_value(v, *, dtype=None, device=None):
         if not is_quantize(v.dtype):
             v = astype(v, dtype)
     else:
-        (v,) = Const(v, dtype=dtype, device=device)()
+        v = Const(v, dtype, device, None)
     return v

@@ -155,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
         if ndim != 0 and ndim != 1:
             raise ValueError("ndim != 1 or 0, get : %d" % ndim)
         if not isinstance(x, (Tensor, SymbolVar)):
-            (x,) = Const(x, dtype=dtype, device=device)(*reference)
+            x = Const(x, dtype, device, reference)
         return x

     if not isinstance(x, collections.abc.Sequence):
@@ -166,7 +166,7 @@
         if dtype is not None:
             x = astype(x, dtype)
         return x
-    (x,) = Const(x, dtype=dtype, device=device)(*reference)
+    x = Const(x, dtype, device, reference)
     return x

@@ -337,7 +337,7 @@ def interpret_subgraph(func, dtype, device):
        return results

    def apply_const(value, dtype=dtype, device=device):
-        return Const(value, dtype=dtype, device=device)()[0]
+        return Const(value, dtype, device, None)

    outputs, outputs_has_grad = func(args, apply_expr, apply_const)
    outputs = [


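Note (not part of the diff): astensor1d now forwards its whole *reference tuple as the fourth argument; the C++ _Const (see tensor_utils.cpp below) unpacks a non-empty tuple to its first element and treats an empty tuple as "no reference". A plain-Python paraphrase of that unpacking; unpack_reference is an illustrative name, not MegEngine API.

def unpack_reference(ref):
    # Mirrors the tuple handling added to _Const: a tuple means the *reference
    # varargs were forwarded unchanged; take its first element, or None if empty.
    if isinstance(ref, tuple):
        return ref[0] if ref else None
    return ref

assert unpack_reference(()) is None
assert unpack_reference(("ref_tensor",)) == "ref_tensor"
assert unpack_reference("ref_tensor") == "ref_tensor"
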
imperative/python/megengine/functional/math.py (+2, -3)

@@ -10,10 +10,9 @@ import collections
 import math
 from typing import Iterable, Optional, Sequence, Tuple, Union

-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
-from ..core.ops.special import Const
 from ..core.tensor.array_method import _matmul
 from ..core.tensor.utils import _normalize_axis
 from ..tensor import Tensor
@@ -729,7 +728,7 @@
     op = builtin.TopK(mode=mode)

     if not isinstance(k, Tensor):
-        (k,) = Const(k, dtype="int32", device=inp.device)()
+        k = Const(k, "int32", inp.device, None)

     if len(inp.shape) == 1:
         if kth_only:


imperative/python/megengine/functional/nn.py (+3, -4)

@@ -11,7 +11,7 @@ from functools import lru_cache
 from typing import NamedTuple, Optional, Sequence, Tuple, Union

 from ..core import _config
-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
 from ..core.ops import builtin
@@ -26,7 +26,6 @@ from ..core.ops.builtin import (
     Reshape,
     TypeCvt,
 )
-from ..core.ops.special import Const
 from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
 from ..core.tensor.utils import (
@@ -1317,7 +1316,7 @@ def batch_norm(
        raise ValueError("Invalid param_dim {}".format(param_dim))

    if x is None:
-        (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
+        x = Const(value, inp.dtype, inp.device, None)
        shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
        (result,) = apply(builtin.Broadcast(), x, shape)
        return result
@@ -1541,7 +1540,7 @@ def sync_batch_norm(

    def _make_full_if_none(x, value):
        if x is None:
-            (x,) = Const(value, dtype=inp.dtype, device=_device)()
+            x = Const(value, inp.dtype, _device, None)
            (result,) = apply(builtin.Broadcast(), x, reduce_shape)
            return result
        elif x.ndim == 1:


imperative/python/megengine/functional/tensor.py (+3, -3)

@@ -13,6 +13,7 @@ import numpy as np

 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import (
+    Const,
     SymbolVar,
     apply,
     broadcast_cpp,
@@ -24,7 +25,6 @@ from ..core._imperative_rt.core2 import (
 from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
-from ..core.ops.special import Const
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -177,7 +177,7 @@ def full(
         shape = (shape,)
     if device is None:
         device = get_default_device()
-    (x,) = Const(value, dtype=dtype, device=device)()
+    x = Const(value, dtype, device, None)
     if type(shape) in (list, tuple) and len(shape) == 0:
         return x
     return broadcast_to(x, shape)
@@ -325,7 +325,7 @@
         [2 2 2]]

     """
-    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+    x = Const(value, inp.dtype, inp.device, inp)
     if inp.ndim == 0:
         return x
     return broadcast_to(x, inp.shape)

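Note (not part of the diff): the public full / full_like API is unchanged; only the internal constant creation now goes through the C++ binding. A quick usage check with arbitrary shape and values, assuming a build that contains this commit.

import megengine.functional as F

a = F.full((2, 3), 1.5, dtype="float32")   # internally Const(value, dtype, device, None)
b = F.full_like(a, 0.0)                    # internally Const(value, a.dtype, a.device, a)
print(a.shape, b.shape)                    # both are (2, 3)
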

imperative/python/megengine/functional/tensor_cache.py (+3, -3)

@@ -1,4 +1,4 @@
-from ..core.ops.special import Const
+from ..core._imperative_rt.core2 import Const
 from ..jit.tracing import is_tracing

 small_tensor_cache = {}
@@ -7,11 +7,11 @@ small_tensor_cache = {}
 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
     if is_tracing():
-        (ret,) = Const(value, dtype=dtype, device=device)()
+        ret = Const(value, dtype, device, None)
     else:
         cache_key = (value, dtype, device)
         if cache_key not in small_tensor_cache:
-            (ret,) = Const(value, dtype=dtype, device=device)()
+            ret = Const(value, dtype, device, None)
             small_tensor_cache[cache_key] = ret
         else:
             ret = small_tensor_cache[cache_key]


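Note (not part of the diff): the caching rule above is untouched, only the constructor call changes. A self-contained paraphrase; make_scalar and is_tracing are stand-ins for the C++ Const binding and the real tracing flag, not MegEngine API.

small_tensor_cache = {}

def make_scalar(value, dtype, device):
    return (value, dtype, device)        # placeholder for the real tensor

def is_tracing():
    return False                         # placeholder for megengine.jit.tracing.is_tracing

def get_scalar_tensor_with_value(value, dtype=None, device=None):
    if is_tracing():
        # while tracing, never reuse a cached tensor
        return make_scalar(value, dtype, device)
    key = (value, dtype, device)
    if key not in small_tensor_cache:
        small_tensor_cache[key] = make_scalar(value, dtype, device)
    return small_tensor_cache[key]

assert get_scalar_tensor_with_value(1, "int32") is get_scalar_tensor_with_value(1, "int32")
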
imperative/python/megengine/traced_module/expr.py (+2, -2)

@@ -16,6 +16,7 @@ from importlib import import_module
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union

 from ..core._imperative_rt import OpDef
+from ..core._imperative_rt.core2 import Const
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     apply,
@@ -25,7 +26,6 @@ from ..core._imperative_rt.core2 import (
     unset_module_tracing,
 )
 from ..core.ops.builtin import FakeQuant
-from ..core.ops.special import Const
 from ..module import Module
 from ..tensor import Parameter, Tensor
 from ..version import __version__
@@ -764,7 +764,7 @@ class Constant(Expr):

     def interpret(self, *inputs):
         if isinstance(self.value, RawTensor):
-            return Const(self.value.numpy())()
+            return (Const(self.value.numpy(), None, None, None),)
         return (self.value,)

     def __repr__(self):


imperative/python/src/tensor.cpp (+2, -0)

@@ -639,6 +639,7 @@ WRAP_FUNC_PY35(squeeze_cpp);
 WRAP_FUNC_PY35(transpose_cpp);
 WRAP_FUNC_PY35(broadcast_cpp);
 WRAP_FUNC_PY35(reshape_cpp);
+WRAP_FUNC_PY35(Const);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -777,6 +778,7 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
             MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
             MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
+            MGE_PY_INTERFACE(Const, Const),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {


imperative/python/src/tensor_utils.cpp (+49, -4)

@@ -94,7 +94,7 @@ bool is_bool_dtype(PyObject* args) {
 }

 py::object _Const(
-        py::handle value, py::handle dtype, py::handle device, py::handle ref) {
+        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
     py::object val = py::reinterpret_borrow<py::object>(value);
     if (PyArray_Check(value.ptr())) {
         py::tuple strides =
@@ -107,21 +107,56 @@ py::object _Const(
         }
         if (need_squeeze) {
             val = py::reinterpret_borrow<py::array>(value);
+            py::object orig_shp = val.attr("shape");
             val = val.attr("squeeze")();
-            val = val.attr("reshape")(val.attr("shape"));
+            val = val.attr("reshape")(orig_shp);
         }
     }
+    py::object ref;
+    if (py::isinstance<py::tuple>(ref_hdl)) {
+        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
+        if (tup.size()) {
+            ref = tup[0];
+        } else {
+            ref = py::none();
+        }
+    } else {
+        ref = py::reinterpret_borrow<py::object>(ref_hdl);
+    }
     if (py::isinstance<PySymbolVar>(ref)) {
         auto ref_var = ref.cast<PySymbolVar*>();
         auto* graph = ref_var->m_node->owner_graph();
-        auto cn = device.cast<CompNode>();
+        CompNode cn;
+        if (device.ptr() == Py_None) {
+            cn = ref_var->m_node->comp_node();
+        } else {
+            cn = device.cast<CompNode>();
+        }
         OperatorNodeConfig config(cn);
         auto hv = npy::np2tensor(
                 val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
         auto typeobj = ref.get_type();
         return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
     }
-    py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
+    py::object device_obj;
+    if (device.ptr() == Py_None) {
+        device_obj = py::cast(CompNode::load(get_default_device()));
+    } else if (py::isinstance<py::str>(device)) {
+        py::object dmap =
+                getattr(py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
+                        "dmap_callback");
+        if (dmap.ptr() != Py_None) {
+            device_obj = dmap(device);
+            py::print(device_obj);
+        } else {
+            device_obj = py::cast(CompNode::load(device.cast<std::string>()));
+        }
+    } else if (py::isinstance<CompNode>(device)) {
+        device_obj = py::reinterpret_borrow<py::object>(device);
+    } else {
+        device_obj = getattr(device, "_cn");
+    }
+    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
     return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
 }
@@ -1107,4 +1142,14 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }

+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _Const(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]),
+                      py::handle(args[3]))
+                .release()
+                .ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 } // namespace mgb::imperative::python

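Note (not part of the diff): a self-contained Python paraphrase of the device-resolution branch added to _Const above; CompNode here is a plain placeholder class and resolve_device an illustrative name, neither is MegEngine API.

class CompNode:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return "CompNode(%r)" % self.name

def get_default_device():
    return "xpux"                            # stand-in for the engine's default device

def resolve_device(device, dmap_callback=None):
    if device is None:                       # None -> load the default device
        return CompNode(get_default_device())
    if isinstance(device, str):              # str -> dmap_callback if set, else load it
        return dmap_callback(device) if dmap_callback else CompNode(device)
    if isinstance(device, CompNode):         # already a CompNode -> use as-is
        return device
    return device._cn                        # otherwise take the wrapper's underlying ._cn

print(resolve_device(None))                  # CompNode('xpux')
print(resolve_device("gpu0"))                # CompNode('gpu0')
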
imperative/python/src/tensor_utils.h (+2, -0)

@@ -20,4 +20,6 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);


 PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);

+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
+
 } // namespace mgb::imperative::python
