@@ -1,40 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor


class Const:
    def __init__(self, value=None, *, dtype=None, device=None):
        self.value = np.asarray(value, dtype=dtype)
        self.dtype = dtype
        self.device = device

    def __call__(self, *reference):
        from ...tensor import Tensor

        device = self.device
        if len(reference) != 0:
            reference = reference[0]
            assert isinstance(
                reference, (SymbolVar, Tensor)
            ), "Reference should be Tensor or VarNode"
            if device is None:
                device = reference.device

            if isinstance(reference, SymbolVar):
                cls = type(reference)
                rst = cls(make_const(reference.graph, self.value, device, self.dtype))
                return (rst,)

        return (Tensor(self.value, self.dtype, self.device, True),)
@@ -14,6 +14,7 @@ import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import (
    Const,
    SymbolVar,
    Tensor,
    _get_convert_inputs,
@@ -28,7 +29,6 @@ from .._imperative_rt.ops import jit_supported
from .._wrap import as_device
from ..autodiff.grad import Function
from ..ops import builtin
from ..ops.special import Const
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
@@ -67,7 +67,7 @@ def convert_single_value(v, *, dtype=None, device=None):
        if not is_quantize(v.dtype):
            v = astype(v, dtype)
    else:
        (v,) = Const(v, dtype=dtype, device=device)()
        v = Const(v, dtype, device, None)
    return v
@@ -155,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
        if ndim != 0 and ndim != 1:
            raise ValueError("ndim != 1 or 0, get : %d" % ndim)
        if not isinstance(x, (Tensor, SymbolVar)):
            (x,) = Const(x, dtype=dtype, device=device)(*reference)
            x = Const(x, dtype, device, reference)
        return x

    if not isinstance(x, collections.abc.Sequence):
@@ -166,7 +166,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
        if dtype is not None:
            x = astype(x, dtype)
        return x
    (x,) = Const(x, dtype=dtype, device=device)(*reference)
    x = Const(x, dtype, device, reference)
    return x
@@ -337,7 +337,7 @@ def interpret_subgraph(func, dtype, device):
        return results

    def apply_const(value, dtype=dtype, device=device):
        return Const(value, dtype=dtype, device=device)()[0]
        return Const(value, dtype, device, None)

    outputs, outputs_has_grad = func(args, apply_expr, apply_const)
    outputs = [
@@ -10,10 +10,9 @@ import collections
import math
from typing import Iterable, Optional, Sequence, Tuple, Union

from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
@@ -729,7 +728,7 @@ def topk(
    op = builtin.TopK(mode=mode)

    if not isinstance(k, Tensor):
        (k,) = Const(k, dtype="int32", device=inp.device)()
        k = Const(k, "int32", inp.device, None)

    if len(inp.shape) == 1:
        if kth_only:
@@ -11,7 +11,7 @@ from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union

from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
@@ -26,7 +26,6 @@ from ..core.ops.builtin import (
    Reshape,
    TypeCvt,
)
from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import (
@@ -1317,7 +1316,7 @@ def batch_norm(
            raise ValueError("Invalid param_dim {}".format(param_dim))

        if x is None:
            (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
            x = Const(value, inp.dtype, inp.device, None)
            shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
            (result,) = apply(builtin.Broadcast(), x, shape)
            return result
@@ -1541,7 +1540,7 @@ def sync_batch_norm(
    def _make_full_if_none(x, value):
        if x is None:
            (x,) = Const(value, dtype=inp.dtype, device=_device)()
            x = Const(value, inp.dtype, _device, None)
            (result,) = apply(builtin.Broadcast(), x, reduce_shape)
            return result
        elif x.ndim == 1:
@@ -13,6 +13,7 @@ import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import (
    Const,
    SymbolVar,
    apply,
    broadcast_cpp,
@@ -24,7 +25,6 @@ from ..core._imperative_rt.core2 import (
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
@@ -177,7 +177,7 @@ def full(
        shape = (shape,)
    if device is None:
        device = get_default_device()
    (x,) = Const(value, dtype=dtype, device=device)()
    x = Const(value, dtype, device, None)
    if type(shape) in (list, tuple) and len(shape) == 0:
        return x
    return broadcast_to(x, shape)
@@ -325,7 +325,7 @@ def full_like(
         [2 2 2]]

    """
    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
    x = Const(value, inp.dtype, inp.device, inp)
    if inp.ndim == 0:
        return x
    return broadcast_to(x, inp.shape)
@@ -1,4 +1,4 @@
from ..core.ops.special import Const
from ..core._imperative_rt.core2 import Const
from ..jit.tracing import is_tracing

small_tensor_cache = {}
@@ -7,11 +7,11 @@ small_tensor_cache = {}
def _get_scalar_tensor_with_value(value, dtype=None, device=None):
    global small_tensor_cache
    if is_tracing():
        (ret,) = Const(value, dtype=dtype, device=device)()
        ret = Const(value, dtype, device, None)
    else:
        cache_key = (value, dtype, device)
        if cache_key not in small_tensor_cache:
            (ret,) = Const(value, dtype=dtype, device=device)()
            ret = Const(value, dtype, device, None)
            small_tensor_cache[cache_key] = ret
        else:
            ret = small_tensor_cache[cache_key]
@@ -16,6 +16,7 @@ from importlib import import_module
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union

from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Const
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    apply,
@@ -25,7 +26,6 @@ from ..core._imperative_rt.core2 import (
    unset_module_tracing,
)
from ..core.ops.builtin import FakeQuant
from ..core.ops.special import Const
from ..module import Module
from ..tensor import Parameter, Tensor
from ..version import __version__
@@ -764,7 +764,7 @@ class Constant(Expr):
    def interpret(self, *inputs):
        if isinstance(self.value, RawTensor):
            return Const(self.value.numpy())()
            return (Const(self.value.numpy(), None, None, None),)
        return (self.value,)

    def __repr__(self):
@@ -639,6 +639,7 @@ WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(Const);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -777,6 +778,7 @@ void init_tensor(py::module m) {
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
            MGE_PY_INTERFACE(Const, Const),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
@@ -94,7 +94,7 @@ bool is_bool_dtype(PyObject* args) {
}

py::object _Const(
        py::handle value, py::handle dtype, py::handle device, py::handle ref) {
        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
    py::object val = py::reinterpret_borrow<py::object>(value);
    if (PyArray_Check(value.ptr())) {
        py::tuple strides =
@@ -107,21 +107,56 @@ py::object _Const(
        }
        if (need_squeeze) {
            val = py::reinterpret_borrow<py::array>(value);
            py::object orig_shp = val.attr("shape");
            val = val.attr("squeeze")();
            val = val.attr("reshape")(val.attr("shape"));
            val = val.attr("reshape")(orig_shp);
        }
    }

    py::object ref;
    if (py::isinstance<py::tuple>(ref_hdl)) {
        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
        if (tup.size()) {
            ref = tup[0];
        } else {
            ref = py::none();
        }
    } else {
        ref = py::reinterpret_borrow<py::object>(ref_hdl);
    }
    if (py::isinstance<PySymbolVar>(ref)) {
        auto ref_var = ref.cast<PySymbolVar*>();
        auto* graph = ref_var->m_node->owner_graph();
        auto cn = device.cast<CompNode>();
        CompNode cn;
        if (device.ptr() == Py_None) {
            cn = ref_var->m_node->comp_node();
        } else {
            cn = device.cast<CompNode>();
        }
        OperatorNodeConfig config(cn);
        auto hv = npy::np2tensor(
                val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
        auto typeobj = ref.get_type();
        return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
    }
    py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
    py::object device_obj;
    if (device.ptr() == Py_None) {
        device_obj = py::cast(CompNode::load(get_default_device()));
    } else if (py::isinstance<py::str>(device)) {
        py::object dmap =
                getattr(py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
                        "dmap_callback");
        if (dmap.ptr() != Py_None) {
            device_obj = dmap(device);
            py::print(device_obj);
        } else {
            device_obj = py::cast(CompNode::load(device.cast<std::string>()));
        }
    } else if (py::isinstance<CompNode>(device)) {
        device_obj = py::reinterpret_borrow<py::object>(device);
    } else {
        device_obj = getattr(device, "_cn");
    }
    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
    return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
@@ -1107,4 +1142,14 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _Const(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]),
                      py::handle(args[3]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

} // namespace mgb::imperative::python
@@ -20,4 +20,6 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);

} // namespace mgb::imperative::python