@@ -16,6 +16,7 @@ import numpy as np
 from .. import _config
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
+from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
 from ..ops import builtin
 from . import amp
 from .indexing import getitem, setitem
@@ -508,12 +509,8 @@ def _reduce(mode):
         elif self.dtype == np.bool_:
             data = data.astype("int32")
         if axis is None:
-            data = data.reshape(-1)
             assert not keepdims, "can not set axis=None and keepdims=True"
-            op = builtin.Reduce(mode=mode, axis=0)
-            (result,) = apply(op, data)
-            result = _remove_axis(result, 0)
+            result = _reduce_to_scalar(builtin.Reduce(mode=mode), data)
         elif isinstance(axis, collections.abc.Iterable):
            axis = _normalize_axis(self.ndim, axis, reverse=True)
            for ai in axis:
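
Reviewer note on the hunk above: the old full-reduction path cost three dispatches (reshape to 1-D, Reduce along axis 0, squeeze the leftover axis), while the new `_reduce_to_scalar` helper issues a single Reduce whose target shape is a scalar (the binding itself is added in the tensor.cpp hunk further down). The two paths agree numerically; a minimal, self-contained C++ analogue of a sum reduction, using only standard-library names:

    #include <numeric>
    #include <vector>

    // Old path: flatten, reduce along axis 0, then drop the axis.
    double sum_flatten_then_reduce(const std::vector<double>& data) {
        std::vector<double> flat(data.begin(), data.end());      // data.reshape(-1)
        return std::accumulate(flat.begin(), flat.end(), 0.0);   // Reduce(axis=0) + squeeze
    }

    // New path: one fused reduce-to-scalar dispatch.
    double sum_reduce_to_scalar(const std::vector<double>& data) {
        return std::accumulate(data.begin(), data.end(), 0.0);
    }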
@@ -69,7 +69,7 @@ class SGD(Optimizer):
         inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
         if inplace_mode:
             _neg_lr = tensor(-lr, dtype="float32")
-            c1 = tensor([1.0])
+            c1 = tensor(1.0)
         for param in param_group["params"]:
             if param.grad is None:
@@ -84,14 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
         device: str = None,
         is_const: bool = False,
         no_cache: bool = False,
-        name: str = "",
+        name: str = None,
     ):
         if name is None:
             name = ""
+        else:
+            self._set_name(name)
         self._custom_name = name
         self._name = name
         self._short_name = name
-        self._set_name(self._name)
         self._prefix = None

     @property
@@ -46,17 +46,17 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
     if (args[1] != Py_None) {
         callback = py::reinterpret_borrow<py::object>(args[1]);
     }
-    GenericFunction generic_callback =
-            [=](Span<ValueRef> inputs) -> std::vector<ValueRef> {
+    GenericFunction generic_callback = [=](Span<ValueRef> inputs) -> ValueRefList {
         mgb_assert(inputs.size() == 1);
         if (callback) {
             callback(TensorWrapper::make(py_tensor_type, inputs[0]));
         }
         return {};
     };
-    tw->m_tensor->reset(imperative::apply(
+    auto attached_value = imperative::apply(
             AttachGrad(m_key), tw->m_tensor->data(),
-            FunctionValue::make(generic_callback))[0]);
+            FunctionValue::make(generic_callback))[0];
+    tw->m_tensor->reset(attached_value);
 }

 void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
@@ -98,7 +98,7 @@ ValueRef make_empty_tensor(
     return res;
 }

-std::optional<std::vector<ValueRef>> elemwise_grad_rule(
+std::optional<ValueRefList> elemwise_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     auto& elemwise = op.cast_final_safe<Elemwise>();
@@ -117,7 +117,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(2);
+        ValueRefList ret(2);
         if (!grad) {
             return ret;
         }
@@ -132,7 +132,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule(
     return imperative::apply(ApplyOp(op), inputs);
 }

-std::optional<std::vector<ValueRef>> reshape_grad_rule(
+std::optional<ValueRefList> reshape_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     mgb_assert(inputs.size() == 2);
@@ -147,7 +147,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(2);
+        ValueRefList ret(2);
         if (!grad) {
             return ret;
         }
@@ -162,7 +162,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule(
     return imperative::apply(ApplyOp(op), inputs);
 }

-std::optional<std::vector<ValueRef>> subtensor_grad_rule(
+std::optional<ValueRefList> subtensor_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     auto&& subtensor = op.cast_final_safe<Subtensor>();
@@ -180,9 +180,9 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule(
                     grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(1);
+        ValueRefList ret(1);
         if (grad && inputs[0]) {
-            SmallVector<ValueRef> args_(inputs.size() + 1);
+            ValueRefList args_(inputs.size() + 1);
             auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
             args_[0] = zeros;
             args_[1] = grad;
@@ -197,7 +197,7 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule(
     return imperative::apply(ApplyOp(op), inputs);
 }

-std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
+std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
@@ -215,9 +215,9 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
                     grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(1);
+        ValueRefList ret(1);
         if (grad && inputs[0]) {
-            SmallVector<ValueRef> args_(inputs.size() + 1);
+            ValueRefList args_(inputs.size() + 1);
             auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
             args_[0] = zeros;
             args_[1] = grad;
@@ -232,7 +232,7 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
     return imperative::apply(ApplyOp(op), inputs);
 }

-std::optional<std::vector<ValueRef>> reduce_grad_rule(
+std::optional<ValueRefList> reduce_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     auto& reduce = op.cast_final_safe<Reduce>();
@@ -251,7 +251,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(1);
+        ValueRefList ret(1);
         if (grad && shapes[0]) {
             ret[0] = broadcast_to(grad, shapes[0]);
         }
@@ -261,7 +261,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule(
     return imperative::apply(ApplyOp(op), inputs);
 }

-std::optional<std::vector<ValueRef>> addAxis_grad_rule(
+std::optional<ValueRefList> addAxis_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     auto&& addAxis = op.cast_final_safe<AddAxis>();
@@ -274,7 +274,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule(
     maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(1);
+        ValueRefList ret(1);
         if (grad && flag_) {
             ret[0] = imperative::apply(*grad_op_, grad)[0];
         }
@@ -284,7 +284,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule(
     return imperative::apply(op, inputs);
 }

-std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
+std::optional<ValueRefList> removeAxis_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
@@ -297,7 +297,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
     maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(1);
+        ValueRefList ret(1);
         if (grad && flag_) {
             ret[0] = imperative::apply(*grad_op_, grad)[0];
         }
@@ -307,7 +307,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
     return imperative::apply(op, inputs);
 }

-std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule(
+std::optional<ValueRefList> fastpathcopy_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     mgb_assert(inputs.size() == 1);
@@ -316,7 +316,7 @@ std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule(
     maker.backward([](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        std::vector<ValueRef> ret(1);
+        ValueRefList ret(1);
         if (grad) {
             ret[0] = grad;
         }
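
Reviewer note: every backward rule above makes the same substitution, replacing the per-call `std::vector<ValueRef>` gradient list with `ValueRefList`, so the typical one- or two-element result no longer pays for a heap allocation on each backward step. A self-contained sketch of the small-buffer idea behind it; `SmallList` and `Value` are illustrative stand-ins, and the real `ValueRefList` additionally takes overflow storage from a dispatch-scoped allocator:

    #include <cstddef>
    #include <cstdio>
    #include <memory>

    struct Value {
        int id = 0;
    };

    // Stores up to N elements inline; only larger lists touch the heap, so
    // grad rules returning one or two values stay allocation-free.
    template <typename T, size_t N = 2>
    class SmallList {
        T m_inline[N];
        std::unique_ptr<T[]> m_heap;
        T* m_data = m_inline;
        size_t m_size;

    public:
        explicit SmallList(size_t size) : m_size(size) {
            if (size > N) {
                m_heap = std::make_unique<T[]>(size);
                m_data = m_heap.get();
            }
        }
        T& operator[](size_t i) { return m_data[i]; }
        size_t size() const { return m_size; }
    };

    int main() {
        SmallList<Value> ret(2);  // like "ValueRefList ret(2);" in the rules above
        ret[0].id = 1;
        ret[1].id = 2;
        std::printf("%zu values, no heap allocation\n", ret.size());
    }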
@@ -25,24 +25,23 @@ private:
     py::function m_hook_fn;
     int m_enabled = 0;

-    std::vector<ValueRef> apply_module_trace_hook(
-            const OpDef& op, Span<ValueRef> input_values) {
+    ValueRefList apply_module_trace_hook(const OpDef& op, Span<ValueRef> input_values) {
         py::list input_tws;
         for (auto&& input_value : input_values) {
             input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
         }
         py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
-        std::vector<ValueRef> outputs;
+        ValueRefList outputs(output_tws.size());
+        auto it = outputs.begin();
         for (auto&& output_tw : output_tws) {
-            outputs.push_back(
-                    TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data());
+            *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data();
         }
         return outputs;
     }

 public:
     ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}

-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override {
         if (op.is<ApplyOp>() && m_enabled > 0) {
             auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
@@ -87,7 +87,7 @@ PyObject* py_apply(
     --nargs;

     auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
-    SmallVector<ValueRef, 64> tensors(nargs);
+    SmallVector<ValueRef, 8> tensors(nargs);

     if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
         // swap to a special context to reuse scalar handle
@@ -100,16 +100,15 @@ PyObject* py_apply(
                 Transformation::top());
         std::make_shared<ScalarTransformation>()->register_at(
                 Transformation::top());
-        SmallVector<ValueRef> inputs(nargs);
         for (size_t i = 0; i < nargs; ++i) {
             auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
             ValueRef input = SymbolValue::make(py_input->m_node);
             if (py_input->is_scalar) {
                 input = ScalarValue::make(input);
             }
-            inputs[i] = input;
+            tensors[i] = input;
         }
-        auto outputs = imperative::apply(*op, inputs);
+        auto outputs = imperative::apply(*op, tensors);
         auto ret = pybind11::tuple(outputs.size());
         auto typeobj = py::handle(args[0]).get_type();
         for (size_t i = 0; i < outputs.size(); ++i) {
@@ -140,7 +139,7 @@ PyObject* py_apply(
         }
     }

-    auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs});
+    auto outputs = imperative::apply(*op, tensors);
     size_t nout = outputs.size();
     auto ret = py::tuple(nout);
     for (size_t i = 0; i < nout; ++i) {
@@ -214,16 +213,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             if (!name.empty()) {
                 m_tensor->reset(
                         imperative::apply(RenameValue(name), m_tensor->data())[0]);
-                mgb_assert(
-                        ((std::string&)*m_tensor->data().name()) == name,
-                        "result name incorrect");
-            }
-            if (data.ndim() == 0) {
-                mgb_assert(m_tensor->is_scalar(), "result should be scalar");
             }
         }
     }
+    mgb_assert(m_tensor->data());
 }

 PyObject* TensorWrapper::module_trace_info() {
@@ -1384,15 +1377,20 @@ void init_tensor(py::module m) {
     std::function<bool(py::object, py::object)> array_comparator;

     bool compare_value(ValueRef lhs, ValueRef rhs) {
-        if (!lhs.shape()->eq(*rhs.shape())) {
+        auto lvalue = lhs.numpy();
+        auto rvalue = rhs.numpy();
+        if (lvalue->shape() != rvalue->shape()) {
             return false;
         }
-        HostTensorND lvalue = lhs.numpy()->as_nd(true);
-        HostTensorND rvalue = rhs.numpy()->as_nd(true);
+        if (lvalue->shape().is_scalar()) {
+            return lvalue->item() == rvalue->item();
+        }
+        HostTensorND lnd = lvalue->as_nd(true);
+        HostTensorND rnd = rvalue->as_nd(true);
         auto larr = py::reinterpret_steal<py::array>(
-                npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE));
+                npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
         auto rarr = py::reinterpret_steal<py::array>(
-                npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE));
+                npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
         return array_comparator(larr, rarr);
     }
@@ -1539,6 +1537,19 @@ void init_tensor(py::module m) {
         }
     });

+    m.def("reduce_to_scalar", [](py::object op, py::object tensor) {
+        auto* tw = TensorWrapper::try_cast(tensor.ptr());
+        auto make_scalar_shape = [&](CompNode device) {
+            return imperative::apply(
+                    CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
+                    HostStorage::make(device))[0];
+        };
+        auto output = imperative::apply(
+                *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(),
+                make_scalar_shape(tw->m_tensor->comp_node()))[0];
+        return TensorWrapper::make(py_tensor_type, output);
+    });
+
     m.def("name_tensor", [](std::string name, py::object tensor) {
         auto* tw = TensorWrapper::try_cast(tensor.ptr());
         auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
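
Reviewer note: the extra input built by `make_scalar_shape` in the binding above is a Const Int32 tensor of layout `{0}`, i.e. an empty list of dimension sizes. An empty shape conventionally denotes a rank-0 (scalar) target, which is what lets a single Reduce collapse every axis. A stand-alone illustration of that convention:

    #include <cstdio>
    #include <vector>

    // An empty shape vector means rank 0: the empty product of its entries
    // is 1, i.e. exactly one element -- a scalar.
    size_t num_elements(const std::vector<int>& shape) {
        size_t n = 1;
        for (int d : shape) {
            n *= static_cast<size_t>(d);
        }
        return n;
    }

    int main() {
        std::vector<int> scalar_shape;  // the shape tensor holds no entries
        std::printf("rank=%zu, elements=%zu\n", scalar_shape.size(),
                    num_elements(scalar_shape));  // rank=0, elements=1
    }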
@@ -1546,9 +1557,9 @@ void init_tensor(py::module m) {
     });

     m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
-        SmallVector<ValueRef> values;
-        for (auto&& tensor : tensors) {
-            values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
+        ValueRefList values(tensors.size());
+        for (size_t i = 0; i < tensors.size(); ++i) {
+            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
         }
         auto outputs = imperative::apply(GetGradKey(), values);
         if (outputs[0].is<GradKeyValue>()) {
@@ -1559,9 +1570,9 @@ void init_tensor(py::module m) {
     });

     m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
-        SmallVector<ValueRef> values;
-        for (auto&& tensor : tensors) {
-            values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
+        ValueRefList values(tensors.size());
+        for (size_t i = 0; i < tensors.size(); ++i) {
+            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
         }
         auto outputs = imperative::apply(GetGradKey(), values);
         if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) {
@@ -1578,7 +1589,7 @@ void init_tensor(py::module m) {
         mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr()));
         auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst();
         GenericFunction generic_backward_fn =
-                [backward_fn](Span<ValueRef> output_grads) -> std::vector<ValueRef> {
+                [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
             py::list output_grad_tws;
             for (auto&& output_grad : output_grads) {
                 if (output_grad) {
@@ -1589,23 +1600,25 @@ void init_tensor(py::module m) {
                 }
             }
             py::tuple input_grad_tws = backward_fn(*output_grad_tws);
-            std::vector<ValueRef> input_grads;
-            for (auto&& input_grad_tw : input_grad_tws) {
+            ValueRefList input_grads(input_grad_tws.size());
+            for (size_t i = 0; i < input_grad_tws.size(); ++i) {
+                auto input_grad_tw = input_grad_tws[i];
                 if (!input_grad_tw.is_none()) {
-                    input_grads.push_back(
-                            py::cast<TensorWrapper>(input_grad_tw).m_tensor->data());
+                    input_grads[i] =
+                            py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
                 } else {
-                    input_grads.push_back({});
+                    input_grads[i] = {};
                 }
             }
             return input_grads;
         };
-        SmallVector<ValueRef> values;
-        for (auto&& input : inputs) {
-            values.push_back(input.cast<TensorWrapper>().m_tensor->data());
+        ValueRefList values(inputs.size() + outputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
         }
-        for (auto&& output : outputs) {
-            values.push_back(output.cast<TensorWrapper>().m_tensor->data());
+        for (size_t i = 0; i < outputs.size(); ++i) {
+            values[i + inputs.size()] =
+                    outputs[i].cast<TensorWrapper>().m_tensor->data();
         }
         auto wrapped_output_values = imperative::apply(
                 SetGrad(key->m_key, generic_backward_fn, inputs.size()), values);
@@ -39,7 +39,7 @@ namespace mgb::imperative::python {
 extern interpreter::Interpreter::Channel* interpreter_for_py;
 extern PyTypeObject* py_tensor_type;

-struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
+struct Tensor : NonCopyableObj {
 private:
     std::string m_name;
     ValueRef m_data;
@@ -52,7 +52,7 @@ public:
     ~Tensor() = default;

     inline std::shared_ptr<Tensor> copy() {
-        auto ret = std::make_shared<Tensor>(m_data.unwrap());
+        auto ret = std::make_shared<Tensor>(m_data);
         ret->m_name = m_name;
         return ret;
     }
@@ -11,7 +11,15 @@

 #pragma once

+#include <optional>
+#include <string>
+
+#include "pybind11/pybind11.h"
+
+#include "megbrain/imperative/dispatch.h"
 #include "megbrain/imperative/transformation.h"
+#include "megbrain/imperative/value.h"
+#include "megbrain/utils/small_vector.h"

 namespace mgb::imperative::python {

 struct TransformationManager {
@@ -58,4 +66,14 @@ struct TransformationManager {
         return sl_instance;
     }
 };
+
+class PyValue final : public MixinValueImpl<PyValue, pybind11::object> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const {
+        return pybind11::str((const pybind11::object&)*this).cast<std::string>();
+    }
+};
+
 }  // namespace mgb::imperative::python
@@ -45,7 +45,7 @@ CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
             layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
 }

-auto CreateTensor::parse(Span<ValueRef> inputs) -> Args {
+auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
     Args result;
     for (auto&& input : inputs) {
         if (auto host_storage = input.as_ref<HostStorage>()) {
@@ -16,70 +16,67 @@
 #include "megbrain/imperative/utils/map.h"

 namespace mgb {
-
-void imperative_log_profile_begin(const char* message);
-void imperative_log_profile(const char* message);
-void imperative_log_profile_end(const char* message);
-
 namespace imperative {

-std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs) {
-    static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH");
-    bool enable_watch = ValueRef::any_watching();
-    auto& context = Transformation::get_context();
-    size_t& depth = context.next_transformation;
-    static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
-    const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1;
-    bool log_current_dispatch = log_dispatch;
-    if (enable_watch) {
-        for (size_t i = 0; i < inputs.size(); ++i) {
-            auto& input = inputs[i];
-            if (input.watching()) {
-                log_current_dispatch = true;
-                mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str());
-                debug::notify_event("apply");
-            }
-        }
-    }
-    // entrance
-    std::vector<ValueRef> outputs;
-    if (depth >= context.transformations.size()) {
-        // fallback
-        if (log_current_dispatch) {
-            mgb_log_debug(
-                    "%sfallback apply %s in %s", tabs, op.to_string().c_str(),
-                    imperative::to_string(inputs).c_str());
-        }
-        outputs = op.fallback(inputs);
-    } else {
-        // dispatch to stack top
-        auto& transformation = *context.transformations[depth];
-        ++depth;
-        context.frames.push_back({op, inputs});
-        CleanupGuard _{[&] {
-            context.frames.pop_back();
-            --depth;
-        }};
-        if (log_current_dispatch) {
-            mgb_log_debug(
-                    "%s%s apply %s in %s", tabs, transformation.name().c_str(),
-                    op.to_string().c_str(), imperative::to_string(inputs).c_str());
-        }
-        outputs = transformation.apply_transformation(op, inputs);
-    }
-    if (log_current_dispatch) {
-        mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str());
-    }
+namespace {
+MGB_NOINLINE void copy_outputs(
+        ForwardAllocator<ValueRef>& allocator, ValueRefList& outputs) {
+    size_t nr_outputs = outputs.size();
+    if (mgb_likely(nr_outputs == 1)) {
+        ValueRef output_copy;
+        output_copy = outputs[0];
+        allocator.clear();
+        outputs = ValueRefList({output_copy});
+    } else if (!outputs.empty()) {
+        SmallVector<ValueRef> outputs_copy(nr_outputs);
+        for (size_t i = 0; i < nr_outputs; ++i) {
+            outputs_copy[i] = outputs[i];
+        }
+        outputs.clear();
+        allocator.clear();
+        outputs = {outputs_copy.begin(), outputs_copy.end()};
+    } else {
+        allocator.clear();
+    }
+}
+}  // namespace
+
+ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
+    auto& context = Transformation::get_context();
+    size_t& depth = context.next_transformation;
+    bool top = depth == 0;
+    auto outputs = ([&] {
+        if (mgb_unlikely(depth >= context.transformations.size())) {
+            return op.fallback(inputs);
+        } else {
+            auto& transformation = *context.transformations[depth++];
+            CleanupGuard _{[&] { --depth; }};
+            return transformation.apply_transformation(op, inputs);
+        }
+    })();
+    if (mgb_unlikely(top)) {
+        copy_outputs(context.allocator, outputs);
+    }
     return outputs;
 }
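
Reviewer note: the rewritten dispatcher drops the per-call logging and leans on `ForwardAllocator`: while nested transformations run, intermediate value lists are bump-allocated, and only when the outermost `apply` returns does `copy_outputs` move the surviving values to stable storage and release the whole arena at once. A self-contained sketch of that bump-and-reset scheme; the `Arena` type below is an illustrative assumption, not MegEngine's allocator:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Bump allocator: allocation is a pointer increment, deallocation is one
    // clear() for everything, mirroring allocator.clear() in copy_outputs.
    class Arena {
        std::vector<unsigned char> m_buf;
        size_t m_used = 0;

    public:
        explicit Arena(size_t capacity) : m_buf(capacity) {}

        void* alloc(size_t bytes) {
            assert(m_used + bytes <= m_buf.size());
            void* p = m_buf.data() + m_used;
            m_used += bytes;
            return p;
        }

        void clear() { m_used = 0; }  // frees every allocation in O(1)
    };

    int main() {
        Arena arena(1024);
        // Nested dispatch: temporaries live in the arena...
        int* tmp = static_cast<int*>(arena.alloc(sizeof(int)));
        *tmp = 42;
        int result = *tmp;  // "copy_outputs": move survivors to stable storage
        arena.clear();      // ...then drop the whole frame at once
        assert(result == 42);
    }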

-std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs) {
+ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {
     return imperative::apply(ApplyOp{def}, inputs);
 }

-std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) {
+ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
     SmallVector<ValueRef> inputs_storage;
     for (size_t i = 0; i < inputs.size(); ++i) {
         inputs_storage.push_back(inputs[i]);
     }
     auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs,
                             size_t) {
-        auto outputs = imperative::apply(ApplyOp(*op), inputs);
+        auto outputs = imperative::apply(*op, inputs);
         return SmallVector<ValueRef>(outputs.begin(), outputs.end());
     };
     auto make_const = [](TensorPtr constant) -> ValueRef {
@@ -101,7 +98,7 @@ std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) {
             DeviceStorage::make(device_value.storage()))[0];
     };
     auto outputs = graph.apply(inputs_storage, apply_functor, make_const);
-    return {outputs.begin(), outputs.end()};
+    return ValueRefList{outputs.begin(), outputs.end()};
 }

 }  // namespace imperative
@@ -126,7 +126,7 @@ public:
             m_frames[m_frames.size() - 1 - i] = {node, node->version()};
             node = node->parent();
         }
-        mgb_assert(node->is_root(), "");
+        mgb_assert(node->is_root());
     }

     Trace() = default;

     std::string to_string() const {
@@ -3,7 +3,7 @@
 namespace mgb {
 namespace imperative {

-std::vector<ValueRef> Operator::fallback(Span<ValueRef> inputs) const {
+ValueRefList Operator::fallback(Span<ValueRef> inputs) const {
     mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str());
 }
@@ -99,19 +99,22 @@ Tensor::Tensor(
 Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
     constexpr int size_threshold = TensorShape::MAX_NDIM;
-    if (hv.layout().total_nr_elems() <= size_threshold) {
+    size_t nr_elems = hv.layout().total_nr_elems();
+    if (nr_elems <= size_threshold) {
         m_value = hv;
     }
-    MGB_RECORD_EVENT(
-            profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
-            dev_tensor().raw_ptr());
-    dev_tensor().copy_from_fixlayout(hv);
-    // even though hv is saved in m_value, Tensor itself could be
-    // released before copy completes
-    MGB_RECORD_EVENT(
-            profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(),
-            hv.raw_ptr(), dev_tensor().raw_ptr());
-    AsyncReleaser::inst()->add(hv);
+    if (nr_elems) {
+        MGB_RECORD_EVENT(
+                profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
+                dev_tensor().raw_ptr());
+        dev_tensor().copy_from_fixlayout(hv);
+        // even though hv is saved in m_value, Tensor itself could be
+        // released before copy completes
+        MGB_RECORD_EVENT(
+                profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(),
+                hv.raw_ptr(), dev_tensor().raw_ptr());
+        AsyncReleaser::inst()->add(hv);
+    }
 }

 Tensor::Tensor(const DeviceTensorND& dv, const HostTensorND& hv) {
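
Reviewer note: the added `if (nr_elems)` guard skips the host-to-device copy, the profiler events, and the async release of `hv` when the tensor is empty, since there is nothing to transfer. In sketch form, with `upload` as a hypothetical stand-in for `dev_tensor().copy_from_fixlayout(hv)`:

    #include <cstddef>
    #include <cstring>

    // Nothing to transfer for an empty layout; skipping also avoids queueing
    // a pointless async release of the host buffer.
    void upload(const void* host, void* device, size_t nr_elems, size_t elem_size) {
        if (nr_elems == 0) {
            return;  // mirrors the new "if (nr_elems)" guard above
        }
        std::memcpy(device, host, nr_elems * elem_size);  // H2D copy stand-in
    }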
@@ -310,7 +310,8 @@ struct ChromeTimelineEventVisitor : EventVisitor<ChromeTimelineEventVisitor> {
         } else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) {
             new_host_event("TensorGetProp", 'X')
                     .dur(0)
-                    .args(current_tensor->detail(current->time));
+                    .args(current_tensor->detail(current->time))
+                    .arg("kind", imperative::to_string(event.prop));
         } else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) {
             new_host_event("TensorWaitProp", 'B');
         } else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) {
@@ -15,71 +15,109 @@
 namespace mgb {
 namespace imperative {

-std::vector<ValueRef> InterpreterTransformation::apply_transformation(
-        const Operator& op, Span<ValueRef> inputs) {
-    if (auto* op_val = op.as<ApplyOp>()) {
-        if (op_val->op().same_type<FastpathCopy>()) {
-            return {inputs[0]};
-        }
-        SmallVector<Handle> input_handles;
-        SmallVector<Handle> output_handles;
-        CleanupGuard _{[&] {
-            for (auto handle : output_handles) {
-                if (handle) {
-                    m_channel->del(handle);
-                }
-            }
-        }};
-        for (auto input : inputs) {
-            input_handles.push_back(*input.cast<InterpreterValue>().handle());
-        }
-        output_handles =
-                m_channel->apply_op(op_val->op().shared_from_this(), input_handles);
-        std::vector<ValueRef> outputs;
-        for (auto& handle : output_handles) {
-            outputs.push_back(InterpreterValue::make(share_handle(handle)));
-            handle = nullptr;
-        }
-        return outputs;
+DTypeValue::ref_t InterpreterInfo::dtype() const {
+    if (!m_dtype) {
+        m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle()));
+    }
+    return m_dtype;
+}
+
+CompNodeValue::ref_t InterpreterInfo::comp_node() const {
+    if (!m_comp_node) {
+        m_comp_node = CompNodeValue::make(
+                handle()->channel()->get_device(handle()->handle()));
+    }
+    return m_comp_node;
+}
+
+ShapeValue::ref_t InterpreterInfo::shape() const {
+    if (!m_shape) {
+        m_shape = ShapeValue::make(
+                ValueShape::from(handle()->channel()->get_shape(handle()->handle())));
+    }
+    return m_shape;
+}
+
+ValueRefList InterpreterTransformation::apply_op(
+        const ApplyOp& apply_op, Span<ValueRef> inputs) {
+    if (apply_op.op().same_type<FastpathCopy>()) {
+        return {inputs[0]};
+    }
+    SmallVector<Handle> input_handles;
+    SmallVector<Handle> output_handles;
+    CleanupGuard _{[&] {
+        for (auto handle : output_handles) {
+            if (handle) {
+                m_channel->del(handle);
+            }
+        }
+    }};
+    for (auto input : inputs) {
+        input_handles.push_back(input.cast<InterpreterValue>().handle()->handle());
+    }
+    output_handles =
+            m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
+    ValueRefList outputs(output_handles.size());
+    for (size_t i = 0; i < output_handles.size(); ++i) {
+        outputs[i] = InterpreterValue::make(share_handle(output_handles[i]));
+        output_handles[i] = nullptr;
+    }
+    return outputs;
+}
+
+ValueRefList InterpreterTransformation::apply_get_attr(
+        const GetAttr& get_attr, Span<ValueRef> inputs) {
+    auto& input = inputs.item().cast<InterpreterValue>();
+    ValueRef output;
+    switch (get_attr.attr()) {
+        case GetAttr::DType:
+            output = input.dtype();
+            break;
+        case GetAttr::Shape:
+            output = input.shape();
+            break;
+        case GetAttr::Device:
+            output = input.comp_node();
+            break;
+        case GetAttr::Value:
+            output = HostValue::make(m_channel->get_value(input.handle()->handle()));
+            break;
+        case GetAttr::Data:
+            output = DeviceValue::make(
+                    m_channel->get_dev_tensor(input.handle()->handle()));
+            break;
+        default:
+            mgb_throw(
+                    MegBrainError, "Interpreter: malformed GetAttr: %s",
+                    get_attr.to_string().c_str());
+    }
+    return {output};
+}
+
+ValueRefList InterpreterTransformation::apply_create_tensor(
+        const CreateTensor& create_tensor, Span<ValueRef> inputs) {
+    auto args = create_tensor.parse(inputs);
+    if (!args.device) {
+        // implies H2D
+        mgb_assert(args.host, "neither host nor device value is valid");
+        return {InterpreterValue::make(share_handle(
+                m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
+    } else {
+        return {InterpreterValue::make(share_handle(m_channel->put(
+                *args.device, args.host ? *args.host : HostTensorND())))};
+    }
+}
+
+ValueRefList InterpreterTransformation::apply_transformation(
+        const Operator& op, Span<ValueRef> inputs) {
+    if (auto* op_val = op.as<ApplyOp>()) {
+        return apply_op(*op_val, inputs);
     } else if (auto* get_attr = op.as<GetAttr>()) {
-        Handle handle = *inputs[0].cast<InterpreterValue>().handle();
-        ValueRef output;
-        switch (get_attr->attr()) {
-            case GetAttr::DType:
-                output = DTypeValue::make(m_channel->get_dtype(handle));
-                break;
-            case GetAttr::Shape:
-                output = ShapeValue::make(
-                        ValueShape::from(m_channel->get_shape(handle)));
-                break;
-            case GetAttr::Device:
-                output = CompNodeValue::make(m_channel->get_device(handle));
-                break;
-            case GetAttr::Value:
-                output = HostValue::make(m_channel->get_value(handle));
-                break;
-            case GetAttr::Data:
-                output = DeviceValue::make(m_channel->get_dev_tensor(handle));
-                break;
-            default:
-                mgb_throw(
-                        MegBrainError, "Interpreter: malformed GetAttr: %s",
-                        op.to_string().c_str());
-        }
-        return {output};
+        return apply_get_attr(*get_attr, inputs);
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
-        auto args = create_tensor->parse(inputs);
-        if (!args.device) {
-            // implies H2D
-            mgb_assert(args.host, "neither host and device value is valid");
-            return {InterpreterValue::make(share_handle(
-                    m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
-        } else {
-            return {InterpreterValue::make(share_handle(m_channel->put(
-                    *args.device, args.host ? *args.host : HostTensorND())))};
-        }
+        return apply_create_tensor(*create_tensor, inputs);
     } else if (auto* dtr_command = op.as<DTRCommand>()) {
-        auto handle = *inputs[0].cast<InterpreterValue>().handle();
+        auto handle = inputs[0].cast<InterpreterValue>().handle()->handle();
         switch (dtr_command->kind()) {
             case DTRCommand::Drop:
                 m_channel->drop(handle);
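
Reviewer note: the new `InterpreterInfo::dtype()`, `comp_node()`, and `shape()` getters memoize their first answer, so repeated attribute queries on the same handle hit a cached value instead of making another trip through the interpreter channel. A stand-alone sketch of the pattern; `Channel` and `Info` are placeholder names:

    #include <optional>
    #include <string>

    struct Channel {
        // Placeholder for the interpreter round-trip, which may block.
        std::string get_dtype() const { return "float32"; }
    };

    class Info {
        const Channel* m_channel;
        mutable std::optional<std::string> m_dtype;  // filled on first query

    public:
        explicit Info(const Channel* channel) : m_channel(channel) {}

        const std::string& dtype() const {
            if (!m_dtype) {
                m_dtype = m_channel->get_dtype();  // pay the round-trip once
            }
            return *m_dtype;  // every later call is a cache hit
        }
    };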
@@ -64,12 +64,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
     size_t count = std::count_if(
             save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
     if (!backward_graph->precomp.empty()) {
-        SmallVector<ValueRef> inputs_and_outputs;
+        ValueRefList inputs_and_outputs(inputs.size() + outputs.size());
+        auto it = inputs_and_outputs.begin();
         for (auto&& input : inputs) {
-            inputs_and_outputs.push_back(input);
+            *it++ = input;
         }
         for (auto&& output : outputs) {
-            inputs_and_outputs.push_back(output);
+            *it++ = output;
         }
         auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
         closure.reserve(precomp.size() + count);
@@ -89,7 +90,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
     }
 }

 void BackwardGraphWithClosure::operator()(
-        std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
+        ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
     ValueRef args[closure.size() + grads.size()];
     size_t nargs = 0;
     for (auto&& value : closure) {
@@ -120,7 +121,7 @@ void BackwardGraphWithClosure::operator()(
 }

 void CustomBackward::operator()(
-        std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
+        ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
     size_t nargs = grads.size();
     ValueRef args[nargs];
     for (size_t i = 0; i < nargs; ++i) {
@@ -201,9 +202,10 @@ void GradKey::backward() {
             mgb_throw(AssertionError, "invalid backward");
         } else {
             mgb_assert(grad_fn->m_slots.size() > 0);
-            std::vector<ValueRef> grads;
+            ValueRefList grads(grad_fn->m_slots.size());
+            auto iter = grads.begin();
             for (auto&& slot : grad_fn->m_slots) {
-                grads.push_back(slot.m_grad);
+                *iter++ = slot.m_grad;
             }
             backward(grads, grad_receiver);
         }
@@ -254,21 +256,28 @@ void GradKey::freeze() {
     m_frozen = true;
 }

-std::vector<ValueRef> GradTransformation::apply_transformation(
+ValueRefList GradTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
-    auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> {
-        SmallVector<ValueRef> unwrapped_inputs;
-        for (auto&& input : inputs) {
-            if (auto grad_value = as_grad_value(input)) {
-                unwrapped_inputs.push_back(grad_value->m_value);
+    auto fallback = [&] {
+        ValueRefList unwrapped_inputs(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            if (auto grad_value = as_grad_value(inputs[i])) {
+                unwrapped_inputs[i] = grad_value->m_value;
             } else {
-                unwrapped_inputs.push_back(input);
+                unwrapped_inputs[i] = inputs[i];
             }
         }
-        return unwrapped_inputs;
+        return imperative::apply(op, unwrapped_inputs);
     };
+    if (auto* get_attr = op.as<GetAttr>()) {
+        if (auto grad_value = as_grad_value(inputs.item())) {
+            return imperative::apply(op, grad_value->m_value);
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    }
     if (m_suppressed) {
-        return imperative::apply(op, unwrap_inputs(inputs));
+        return fallback();
     }
     if (auto* op_val = op.as<ApplyOp>()) {
         size_t nr_require_grad = 0;
@@ -284,20 +293,21 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
         if (nr_require_grad == 0) {
             return imperative::apply(op, inputs);
         }
-        SmallVector<ValueRef> captured_inputs;
-        SmallVector<bool> inputs_require_grad;
+        ValueRefList captured_inputs(inputs.size());
+        SmallVector<bool> inputs_require_grad(inputs.size());
         // capture value so that trace could assume input as same
         auto capture_value = [](ValueRef value) {
             // TODO: fastpath copy shouldn't be an OpDef
             return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
         };
-        for (auto& input : inputs) {
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            auto& input = inputs[i];
             if (auto grad_value = as_grad_value(input)) {
-                captured_inputs.push_back(capture_value(grad_value->m_value));
-                inputs_require_grad.push_back(true);
+                captured_inputs[i] = capture_value(grad_value->m_value);
+                inputs_require_grad[i] = true;
             } else {
-                captured_inputs.push_back(capture_value(input));
-                inputs_require_grad.push_back(false);
+                captured_inputs[i] = capture_value(input);
+                inputs_require_grad[i] = false;
             }
         }
         decltype(std::declval<GradFn>().m_backward) backward_storage;
@@ -373,9 +383,11 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
         mgb_assert(!grad_fn->m_slots.empty());
         m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
         return outputs;
+    } else if (op.is<CreateTensor>()) {
+        return imperative::apply(op, inputs);
     } else if (auto* attach_grad = op.as<AttachGrad>()) {
         if (!has_key(attach_grad->key())) {
-            return imperative::apply(op, unwrap_inputs(inputs));
+            return fallback();
         }
         auto tensor = inputs[0];
         GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
@@ -386,7 +398,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
         return {record_grad(output)};
     } else if (auto* grad_backward = op.as<GradBackward>()) {
         if (!has_key(grad_backward->key())) {
-            return imperative::apply(op, unwrap_inputs(inputs));
+            return fallback();
         }
         size_t nr_grads = inputs.size() / 2;
         mgb_assert(nr_grads * 2 == inputs.size());
@@ -416,7 +428,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
         backward.m_output_attrs =
                 SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
         backward.m_backward = set_grad->grad_fn();
-        std::vector<ValueRef> outputs;
+        ValueRefList outputs(nr_outputs);
         grad_fn->m_key = m_key;
         grad_fn->m_slots.resize(nr_outputs);
         grad_fn->m_dests.reserve(nr_inputs);
@@ -439,13 +451,13 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
             } else {
                 grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
             }
-            outputs.push_back(record_grad(grad_value));
+            outputs[i] = record_grad(grad_value);
         }
         m_key->m_tape.push_back({grad_fn, nullptr});
         return outputs;
     } else if (auto* gbc = op.as<GetBackwardColsure>()) {
         if (gbc->key() != m_key) {
-            return imperative::apply(op, unwrap_inputs(inputs));
+            return fallback();
         }
         return {FunctionValue::make(make_backward_closure(inputs))};
     } else if (op.is<DetachGrad>()) {
@@ -471,21 +483,8 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
         } else {
             return imperative::apply(op, inputs);
         }
-    } else if (op.is<CreateTensor>()) {
-        return imperative::apply(op, inputs);
     } else {
-        SmallVector<ValueRef> unwrapped_inputs;
-        for (auto&& input : inputs) {
-            if (auto grad_value = as_grad_value(input)) {
-                unwrapped_inputs.push_back(grad_value->m_value);
-            } else {
-                unwrapped_inputs.push_back(input);
-            }
-        }
-        auto outputs = imperative::apply(
-                op, {unwrapped_inputs.data(), unwrapped_inputs.size()});
-        mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty());
-        return outputs;
+        return fallback();
     }
 }
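
Reviewer note: every miss path in this function now funnels through the single `fallback` lambda, which unwraps any grad-tracking wrappers and forwards the op in one step, where the old code repeated an unwrap-then-apply sequence at each call site. A minimal sketch of the shape of that pattern, with stand-in types:

    #include <cstddef>
    #include <functional>
    #include <vector>

    struct Value {
        int payload = 0;
        bool wrapped = false;  // stand-in for "carries a GradValue wrapper"
    };

    using Op = std::function<std::vector<Value>(const std::vector<Value>&)>;

    // Unwrap every input, then forward in one step -- the body of fallback().
    std::vector<Value> apply_fallback(const Op& op, const std::vector<Value>& inputs) {
        std::vector<Value> unwrapped(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            unwrapped[i] = inputs[i].wrapped ? Value{inputs[i].payload, false}
                                             : inputs[i];
        }
        return op(unwrapped);  // call sites reduce to "return fallback();"
    }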
@@ -500,8 +499,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
             y_slots.emplace_back();
         }
     }
-    GenericFunction closure = [grad_key,
-                               y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> {
+    GenericFunction closure = [grad_key, y_slots](Span<ValueRef> dys) -> ValueRefList {
         size_t nr_grads = y_slots.size();
         mgb_assert(dys.size() == nr_grads);
         for (size_t i = 0; i < nr_grads; ++i) {
@@ -21,7 +21,7 @@
 namespace mgb {
 namespace imperative {

-std::vector<ValueRef> LazyEvalTransformation::apply_transformation(
+ValueRefList LazyEvalTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
     if (auto* op_val = op.as<ApplyOp>()) {
         static std::unordered_set<Typeinfo*> mm_io_ops = {
@@ -59,9 +59,9 @@ std::vector<ValueRef> LazyEvalTransformation::apply_transformation(
             mgb_assert(!output_nodes.empty());
             m_io_link = SymbolVar(output_nodes[0]);
         }
-        std::vector<ValueRef> outputs;
-        for (auto&& output_node : output_nodes) {
-            outputs.push_back(record_var(output_node));
+        ValueRefList outputs(output_nodes.size());
+        for (size_t i = 0; i < output_nodes.size(); ++i) {
+            outputs[i] = record_var(output_nodes[i]);
         }
         return outputs;
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
@@ -19,26 +19,8 @@ namespace imperative { | |||||
namespace { | namespace { | ||||
using ScalarRule = std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>; | |||||
static std::unordered_map< | |||||
Typeinfo*, std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>> | |||||
scalar_rules; | |||||
ValueRef unwrap_input(ValueRef input) { | |||||
if (auto scalar_input = input.as_ref<ScalarValue>()) { | |||||
return scalar_input->value(); | |||||
} else { | |||||
return input; | |||||
} | |||||
} | |||||
std::vector<ValueRef> unwrap_inputs(Span<ValueRef> inputs) { | |||||
std::vector<ValueRef> unwrapped_inputs; | |||||
for (auto&& input : inputs) { | |||||
unwrapped_inputs.push_back(unwrap_input(input)); | |||||
} | |||||
return unwrapped_inputs; | |||||
} | |||||
using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>); | |||||
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules; | |||||
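The new ScalarRule typedef swaps the std::function wrapper for a plain function pointer keyed by Typeinfo*, so each registered rule costs one raw pointer instead of a heap-allocating callable object, and registration can be done with capture-less lambdas. A minimal standalone sketch of that pattern; the names Rule, Negate, and register_rule are illustrative, not from the diff:

#include <cstdio>
#include <typeindex>
#include <unordered_map>

// Stand-ins for OpDef/ValueRef, just enough to make the sketch self-contained.
struct OpDef { virtual ~OpDef() = default; };
struct Negate : OpDef {};
using Rule = int (*)(const OpDef&, int);  // plain function pointer, no allocation

static std::unordered_map<std::type_index, Rule> rules;

// A capture-less lambda converts to a function pointer, so the table stores
// one raw pointer per op type instead of a std::function object.
template <typename T, int (*rule)(const T&, int)>
void register_rule() {
    rules[typeid(T)] = [](const OpDef& def, int x) {
        return rule(static_cast<const T&>(def), x);
    };
}

int negate_rule(const Negate&, int x) { return -x; }

int main() {
    register_rule<Negate, negate_rule>();
    Negate op;
    std::printf("%d\n", rules[typeid(op)](op, 3));  // prints -3
}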
ValueRef make_scalar_shape(CompNode device) {
    HostTensorND scalar_shape(device, {1}, dtype::Int32());
@@ -49,9 +31,6 @@ ValueRef make_scalar_shape(CompNode device) {
}
bool is_scalar_shape(ValueRef shape) {
    if (shape.is<ScalarValue>()) {
        return false;
    }
    // may have performance issue
    auto shape_of_shape = shape.shape();
    if (!shape_of_shape) {
@@ -61,74 +40,65 @@ bool is_scalar_shape(ValueRef shape) {
    return *shape_of_shape == ValueShape{0};
}
template <typename T>
void register_scalar_rule(std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>)) {
    scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef> inputs) {
        return (*rule)(def.cast_final_safe<T>(), inputs);
template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)>
void register_scalar_rule() {
    scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs,
                                     Span<bool> inputs_mask) {
        return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask);
    };
}
std::vector<ValueRef> elemwise_rule(const Elemwise& elem, Span<ValueRef> inputs) {
template <typename TOpDef, size_t nr_inputs>
ValueRefList elemwise_rule(
        const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    if constexpr (nr_inputs != 0) {
        mgb_assert(inputs.size() == nr_inputs, "inputs size mismatch");
    }
    bool all_scalar = true;
    for (auto&& input : inputs) {
        if (!input.is<ScalarValue>()) {
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
            break;
        }
    }
    auto output = imperative::apply(elem, unwrap_inputs(inputs))[0];
    auto outputs = imperative::apply(op_def, inputs);
    if (all_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
        outputs[0] = ScalarValue::make(outputs[0]);
    }
    return outputs;
}
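Under the new signature the caller has already unwrapped scalar inputs into plain 1-element tensors and records which positions were scalars in inputs_mask; a rule computes on the raw values and re-tags the output as scalar only when every input was one. A self-contained toy version of that convention, with the invented names TaggedValue and add_rule standing in for ValueRef and a real rule:

#include <cstdio>
#include <vector>

// Illustrative stand-in: a value plus a flag recording "logically a scalar".
struct TaggedValue {
    float data;
    bool is_scalar;
};

// Mirrors the mask convention: compute on raw values, then restore the
// scalar tag on the output iff every input carried it.
TaggedValue add_rule(const std::vector<TaggedValue>& inputs) {
    bool all_scalar = true;
    float sum = 0.f;
    for (auto&& input : inputs) {
        all_scalar &= input.is_scalar;
        sum += input.data;
    }
    return {sum, all_scalar};
}

int main() {
    TaggedValue a{1.f, true}, b{2.f, false};
    auto out = add_rule({a, b});
    std::printf("%g scalar=%d\n", out.data, out.is_scalar);  // 3 scalar=0
}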
std::vector<ValueRef> remove_axis_rule(
        const RemoveAxis& remove_axis, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    mgb_assert(!inputs[0].is<ScalarValue>());
    auto output = imperative::apply(remove_axis, inputs)[0];
    bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size();
ValueRefList remove_axis_rule(
        const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(!inputs_mask.item());
    bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size();
    if (is_scalar && remove_axis.axis.size() == 1) {
        return {ScalarValue::make(inputs.item())};
    }
    auto outputs = imperative::apply(remove_axis, inputs);
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
        outputs[0] = ScalarValue::make(outputs[0]);
    }
    return outputs;
}
std::vector<ValueRef> reduce_rule(const Reduce& reduce, Span<ValueRef> inputs) {
ValueRefList reduce_rule(
        const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    if (inputs.size() == 1) {
        return imperative::apply(reduce, unwrap_inputs(inputs));
        return imperative::apply(reduce, inputs);
    }
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    if (is_scalar) {
        auto unwrapped_input = unwrap_input(inputs[0]);
        CompNode device = *unwrapped_input.device();
        return {ScalarValue::make(imperative::apply(
                reduce, unwrapped_input, make_scalar_shape(device))[0])};
    }
    auto output = imperative::apply(reduce, unwrap_inputs(inputs))[0];
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}
std::vector<ValueRef> typecvt_rule(const TypeCvt& typecvt, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        CompNode device = *inputs[0].device();
        return {ScalarValue::make(
                imperative::apply(typecvt, scalar_input->value())[0])};
    } else {
        return imperative::apply(typecvt, inputs);
        imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])};
    }
    return imperative::apply(reduce, inputs);
}
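Internally a scalar travels as a 1-element tensor, so when the target-shape input denotes a scalar the rule substitutes the rank-1 {1} shape built by make_scalar_shape for the kernel and restores the scalar tag afterwards. A rough standalone sketch of that round-trip under simplified stand-ins; MiniTensor and reduce_sum are invented for illustration:

#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in tensor: flat data plus an explicit shape; an empty shape vector
// plays the role of the "scalar" target shape in the rule above.
struct MiniTensor {
    std::vector<float> data;
    std::vector<int> shape;  // empty => scalar
};

// Reduce-to-shape, sketched for the full-reduction case only: a scalar
// target is rewritten to shape {1} for the kernel, then reported back
// as a scalar to the caller.
MiniTensor reduce_sum(const MiniTensor& x, std::vector<int> target_shape) {
    bool to_scalar = target_shape.empty();
    if (to_scalar) {
        target_shape = {1};  // the kernel never sees rank-0 shapes
    }
    assert(target_shape == std::vector<int>{1});  // sketch handles full reduction only
    float sum = 0.f;
    for (float v : x.data) {
        sum += v;
    }
    MiniTensor out{{sum}, target_shape};
    if (to_scalar) {
        out.shape.clear();  // restore the scalar view
    }
    return out;
}

int main() {
    MiniTensor x{{1.f, 2.f, 3.f}, {3}};
    auto y = reduce_sum(x, {});
    std::printf("%g rank=%zu\n", y.data[0], y.shape.size());  // 6 rank=0
}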
std::vector<ValueRef> collective_comm_rule(
        const CollectiveComm& collective_comm, Span<ValueRef> inputs) {
ValueRefList collective_comm_rule(
        const CollectiveComm& collective_comm, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 1);
    static std::unordered_set<CollectiveComm::Mode> modes = {
            CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
@@ -138,17 +108,17 @@ std::vector<ValueRef> collective_comm_rule(
    if (modes.count(collective_comm.mode) == 0) {
        return imperative::apply(collective_comm, inputs);
    }
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        return {ScalarValue::make(
                imperative::apply(collective_comm, scalar_input->value())[0])};
    if (inputs_mask.item()) {
        return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])};
    } else {
        return imperative::apply(collective_comm, inputs);
    }
}
std::vector<ValueRef> param_pack_split_rule(
        const ParamPackSplit& param_pack_split, Span<ValueRef> inputs) {
    auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs));
ValueRefList param_pack_split_rule(
        const ParamPackSplit& param_pack_split, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    auto outputs = imperative::apply(param_pack_split, inputs);
    size_t nr_outputs = outputs.size();
    mgb_assert(nr_outputs == param_pack_split.shapes.size());
    for (size_t i = 0; i < nr_outputs; ++i) {
@@ -159,29 +129,28 @@ std::vector<ValueRef> param_pack_split_rule(
    return outputs;
}
std::vector<ValueRef> dot_rule(const Dot& dot, Span<ValueRef> inputs) {
    return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])};
ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    return {ScalarValue::make(imperative::apply(dot, inputs)[0])};
}
std::vector<ValueRef> add_axis_rule(const AddAxis& add_axis, Span<ValueRef> inputs) {
ValueRefList add_axis_rule(
        const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 1);
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
    if (inputs_mask.item()) {
        mgb_assert(add_axis.axis[0] == 0);
        if (add_axis.axis.size() == 1) {
            return {scalar_input->value()};
            return {inputs[0]};
        } else {
            std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end());
            return imperative::apply(
                    ApplyOp(*AddAxis::make(axis, add_axis.scope())),
                    scalar_input->value());
            return imperative::apply(*AddAxis::make(axis, add_axis.scope()), inputs[0]);
        }
    } else {
        return imperative::apply(add_axis, inputs);
    }
}
std::vector<ValueRef> remote_recv_rule(
        const RemoteRecv& remote_recv, Span<ValueRef> inputs) {
ValueRefList remote_recv_rule(
        const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    if (remote_recv.shape.empty()) {
        std::vector<int32_t> shape = {1};
        auto remote_recv_no_scalar = RemoteRecv::make(
@@ -189,32 +158,32 @@ std::vector<ValueRef> remote_recv_rule(
                remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype,
                remote_recv.backend);
        remote_recv_no_scalar->set_scope(remote_recv.scope());
        return imperative::apply(
                ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs));
        return imperative::apply(ApplyOp(*remote_recv_no_scalar), inputs);
    } else {
        return imperative::apply(remote_recv, unwrap_inputs(inputs));
        return imperative::apply(remote_recv, inputs);
    }
}
std::vector<ValueRef> check_no_finite_rule(
        const CheckNonFinite& check_no_finite, Span<ValueRef> inputs) {
    auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs));
ValueRefList check_no_finite_rule(
        const CheckNonFinite& check_no_finite, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    auto outputs = imperative::apply(check_no_finite, inputs);
    mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
    outputs.back() = ScalarValue::make(outputs.back());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (inputs[i].is<ScalarValue>()) {
        if (inputs_mask[i]) {
            outputs[i] = ScalarValue::make(outputs[i]);
        }
    }
    return outputs;
}
std::vector<ValueRef> subtensor_rule(
        const Subtensor& subtensor, Span<ValueRef> inputs) {
ValueRefList subtensor_rule(
        const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() >= 1);
    auto input = inputs[0];
    bool is_scalar;
    mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
    mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
    if (auto shape = input.shape()) {
        size_t ndim = input.shape()->ndim;
        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
@@ -226,25 +195,25 @@ std::vector<ValueRef> subtensor_rule(
    } else {
        is_scalar = false;
    }
    auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
    auto outputs = imperative::apply(subtensor, inputs);
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
        outputs[0] = ScalarValue::make(outputs[0]);
    }
    return outputs;
}
std::vector<ValueRef> get_var_shape_rule(
        const GetVarShape& get_var_shape, Span<ValueRef> inputs) {
ValueRefList get_var_shape_rule(
        const GetVarShape& get_var_shape, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    bool all_scalar = true;
    mgb_assert(inputs.size() >= 1);
    for (auto&& input : inputs) {
        if (!input.is<ScalarValue>()) {
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    if (all_scalar) {
        auto device = inputs[0].cast<ScalarValue>().value().device();
        auto device = inputs[0].device();
        auto storage = HostStorage::make(*device);
        // storage->ensure_size(1);
        return imperative::apply(
@@ -252,88 +221,49 @@ std::vector<ValueRef> get_var_shape_rule(
                        CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}),
                storage);
    } else {
        return imperative::apply(get_var_shape, unwrap_inputs(inputs));
    }
}
std::vector<ValueRef> fastpath_copy_rule(
        const FastpathCopy& fastpath_copy, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    bool is_scalar = inputs[0].is<ScalarValue>();
    auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0];
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
        return imperative::apply(get_var_shape, inputs);
    }
}
std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
ValueRefList reshape_rule(
        const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    auto unwrapped_input = inputs[0].is<ScalarValue>()
                                 ? inputs[0].cast<ScalarValue>().value()
                                 : inputs[0];
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                reshape, unwrapped_input,
                make_scalar_shape(*unwrapped_input.device()))[0])};
                reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
    } else {
        return imperative::apply(reshape, unwrap_inputs(inputs));
        return imperative::apply(reshape, inputs);
    }
}
std::vector<ValueRef> broadcast_rule(
        const Broadcast& broadcast, Span<ValueRef> inputs) {
ValueRefList broadcast_rule(
        const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    auto unwrapped_input = inputs[0].is<ScalarValue>()
                                 ? inputs[0].cast<ScalarValue>().value()
                                 : inputs[0];
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                broadcast, unwrapped_input,
                make_scalar_shape(*unwrapped_input.device()))[0])};
    } else {
        return imperative::apply(broadcast, unwrap_inputs(inputs));
    }
}
std::vector<ValueRef> copy_rule(const Copy& copy, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    bool is_scalar = inputs[0].is<ScalarValue>();
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])};
    } else {
        return imperative::apply(copy, unwrap_inputs(inputs));
    }
}
std::vector<ValueRef> inplace_add_rule(
        const InplaceAdd& inplace_add, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 4);
    bool is_scalar = inputs[0].is<ScalarValue>();
    if (is_scalar) {
        return {ScalarValue::make(
                imperative::apply(inplace_add, unwrap_inputs(inputs))[0])};
                broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
    } else {
        return imperative::apply(inplace_add, unwrap_inputs(inputs));
        return imperative::apply(broadcast, inputs);
    }
}
template <typename T>
std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
ValueRefList subgraph_op_rule(
        const T& op, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    // TODO: add flag instead of assume
    bool all_scalar = true;
    for (auto&& input : inputs) {
        if (!input.is<ScalarValue>()) {
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    auto outputs = imperative::apply(op, unwrap_inputs(inputs));
    auto outputs = imperative::apply(op, inputs);
    if (all_scalar) {
        for (auto& output : outputs) {
            output = ScalarValue::make(output);
            output = scalar_type.make(output);
        }
    }
    return outputs;
@@ -341,67 +271,54 @@ std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
struct ScalarRuleRegistry {
    ScalarRuleRegistry() {
        register_scalar_rule(elemwise_rule);
        register_scalar_rule(remove_axis_rule);
        register_scalar_rule(reduce_rule);
        register_scalar_rule(typecvt_rule);
        register_scalar_rule(collective_comm_rule);
        register_scalar_rule(param_pack_split_rule);
        register_scalar_rule(dot_rule);
        register_scalar_rule(add_axis_rule);
        register_scalar_rule(remote_recv_rule);
        register_scalar_rule(check_no_finite_rule);
        register_scalar_rule(subtensor_rule);
        register_scalar_rule(get_var_shape_rule);
        register_scalar_rule(fastpath_copy_rule);
        register_scalar_rule(reshape_rule);
        register_scalar_rule(broadcast_rule);
        register_scalar_rule(copy_rule);
        register_scalar_rule(inplace_add_rule);
        register_scalar_rule(subgraph_op_rule<SubgraphOp>);
        register_scalar_rule(subgraph_op_rule<CompiledOp>);
        register_scalar_rule<Elemwise, elemwise_rule<Elemwise, 0>>();
        register_scalar_rule<RemoveAxis, remove_axis_rule>();
        register_scalar_rule<Reduce, reduce_rule>();
        register_scalar_rule<TypeCvt, elemwise_rule<TypeCvt, 1>>();
        register_scalar_rule<CollectiveComm, collective_comm_rule>();
        register_scalar_rule<ParamPackSplit, param_pack_split_rule>();
        register_scalar_rule<Dot, dot_rule>();
        register_scalar_rule<AddAxis, add_axis_rule>();
        register_scalar_rule<RemoteRecv, remote_recv_rule>();
        register_scalar_rule<CheckNonFinite, check_no_finite_rule>();
        register_scalar_rule<Subtensor, subtensor_rule>();
        register_scalar_rule<GetVarShape, get_var_shape_rule>();
        register_scalar_rule<FastpathCopy, elemwise_rule<FastpathCopy, 1>>();
        register_scalar_rule<Reshape, reshape_rule>();
        register_scalar_rule<Broadcast, broadcast_rule>();
        register_scalar_rule<Copy, elemwise_rule<Copy, 1>>();
        register_scalar_rule<InplaceAdd, elemwise_rule<InplaceAdd, 4>>();
        register_scalar_rule<SubgraphOp, subgraph_op_rule<SubgraphOp>>();
        register_scalar_rule<CompiledOp, subgraph_op_rule<CompiledOp>>();
    }
} _;
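Registration happens as a side effect of constructing the file-scope `_` instance, so the rule table is populated during static initialization, before any transformation can look up an op. The same idiom in a compilable toy form; Registry, double_rule, and the string key are illustrative stand-ins:

#include <cstdio>
#include <string>
#include <unordered_map>

// Illustrative registry populated at static-initialization time, mirroring
// the ScalarRuleRegistry pattern: a file-scope instance whose constructor
// fills the table before main() runs.
static std::unordered_map<std::string, int (*)(int)> table;

int double_rule(int x) { return 2 * x; }

namespace {
struct Registry {
    Registry() { table["double"] = double_rule; }
} _;  // anonymous-namespace instance, just like the `} _;` above
}  // namespace

int main() {
    std::printf("%d\n", table["double"](21));  // prints 42
}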
} // namespace
std::vector<ValueRef> ScalarTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(apply_op->op(), inputs);
        } else {
            // TODO: repeat op
            return imperative::apply(op, unwrap_inputs(inputs));
        }
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->shape().is_scalar()) {
            ValueShape scalar_shape = {1};
            CreateTensor scalar_op(
                    create_tensor->kind(), create_tensor->device(),
                    create_tensor->dtype(), scalar_shape);
            return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
        auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
        if (!is_scalar) {
            return {output};
ValueRefList ScalarTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    auto&& input = inputs.item();
    bool is_scalar = input.is<ScalarValue>();
    if (!is_scalar) {
        return imperative::apply(get_attr, input);
    }
    auto unwrapped_input = input.cast<ScalarValue>().value();
    if (get_attr.attr() == GetAttr::Shape) {
        if (!m_empty_shape) {
            m_empty_shape = ShapeValue::make();
        }
        switch (get_attr->attr()) {
            case GetAttr::Shape: {
                // Scalar Shape
                return {ShapeValue::make()};
            }
        return {m_empty_shape};
    } else {
        auto outputs = imperative::apply(get_attr, unwrapped_input);
        auto& output = outputs[0];
        switch (get_attr.attr()) {
            case GetAttr::Value: {
                auto& hv = output.cast<HostValue>();
                mgb_assert(
                        hv.shape() == ValueShape({1}),
                        "underlying value should has shape {1}, got %s",
                        hv.shape().to_string().c_str());
                return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())};
                output = HostValue::make(hv.dtype(), ValueShape(), hv.storage());
                break;
            }
            case GetAttr::Data: {
                auto& dv = output.cast<DeviceValue>();
@@ -409,22 +326,67 @@ std::vector<ValueRef> ScalarTransformation::apply_transformation(
                        dv.shape() == ValueShape({1}),
                        "underlying value should has shape {1}, got %s",
                        dv.shape().to_string().c_str());
                return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())};
                output = DeviceValue::make(dv.dtype(), ValueShape(), dv.storage());
                break;
            }
            default:
                return {output};
                break;
        }
        return outputs;
    }
}
ValueRefList ScalarTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* get_attr = op.as<GetAttr>()) {
        // fastpath for GetAttr
        return apply_get_attr(*get_attr, inputs);
    }
    size_t nr_inputs = inputs.size();
    ValueRefList unwrapped_inputs(nr_inputs);
    bool inputs_mask[nr_inputs];
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) {
            unwrapped_inputs[i] = scalar_value->value();
            inputs_mask[i] = true;
        } else {
            unwrapped_inputs[i] = inputs[i];
            inputs_mask[i] = false;
        }
    }
    auto fallback = [&] { return imperative::apply(op, unwrapped_inputs); };
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(
                    apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs});
        } else {
            // TODO: repeat op
            return fallback();
        }
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->shape().is_scalar()) {
            ValueShape scalar_shape = {1};
            CreateTensor scalar_op(
                    create_tensor->kind(), create_tensor->device(),
                    create_tensor->dtype(), scalar_shape);
            return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (op.as<IsScalar>()) {
        return {BoolValue::make(inputs.as_array<1>()[0].is<ScalarValue>())};
        mgb_assert(nr_inputs == 1);
        return {BoolValue::make(inputs_mask[0])};
    } else if (op.is<Operator::IdentityLike>()) {
        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
        mgb_assert(nr_inputs == 1);
        bool is_scalar = inputs_mask[0];
        auto outputs = fallback();
        if (is_scalar) {
            return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])};
        } else {
            return imperative::apply(op, inputs);
            outputs[0] = ScalarValue::make(outputs[0]);
        }
        return outputs;
    } else {
        return imperative::apply(op, unwrap_inputs(inputs));
        return fallback();
    }
};
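One detail worth flagging in the new apply_transformation: `bool inputs_mask[nr_inputs]` is a variable-length array, a GCC/Clang extension rather than standard C++. A portable sketch of one possible alternative, a fixed inline buffer sized for the common case; the cap and names here are assumptions, not from the diff:

#include <cstddef>
#include <cstdio>

// Portable stand-in for the `bool inputs_mask[nr_inputs]` VLA above:
// a fixed inline buffer sized for the common case. Real code might
// fall back to heap allocation past the cap instead of erroring out.
constexpr size_t kMaxInlineInputs = 16;

int main() {
    size_t nr_inputs = 3;  // known only at runtime
    bool inputs_mask[kMaxInlineInputs] = {};
    if (nr_inputs > kMaxInlineInputs) {
        std::fprintf(stderr, "too many inputs for the inline buffer\n");
        return 1;
    }
    for (size_t i = 0; i < nr_inputs; ++i) {
        inputs_mask[i] = (i % 2 == 0);  // arbitrary demo predicate
    }
    std::printf("%d %d %d\n", inputs_mask[0], inputs_mask[1], inputs_mask[2]);
}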
@@ -0,0 +1,25 @@
/**
 * \file imperative/src/impl/transformations/tangent.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
 */
#include "megbrain/imperative/transformations/tangent.h"
namespace mgb {
namespace imperative {
ValueRefList TangentTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* apply_op = op.as<ApplyOp>()) {
    }
    mgb_assert(false);
}
} // namespace imperative
} // namespace mgb
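The new tangent.cpp is only a stub: apply_transformation matches ApplyOp and then asserts unconditionally. The name suggests forward-mode (tangent) differentiation. Purely as a conceptual illustration of what tangent propagation means, and emphatically not what this diff implements, here is a standalone dual-number example:

#include <cstdio>

// Conceptual illustration only: forward-mode ("tangent") differentiation
// via dual numbers. Each op propagates a (value, derivative) pair.
struct Dual {
    double v;  // primal value
    double t;  // tangent: d(value)/d(input)
};

Dual mul(Dual a, Dual b) {
    return {a.v * b.v, a.t * b.v + a.v * b.t};  // product rule
}

Dual add(Dual a, Dual b) {
    return {a.v + b.v, a.t + b.t};
}

int main() {
    Dual x{3.0, 1.0};            // seed dx/dx = 1
    Dual y = add(mul(x, x), x);  // y = x^2 + x
    std::printf("y=%g dy/dx=%g\n", y.v, y.t);  // y=12 dy/dx=7
}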
@@ -153,7 +153,7 @@ VarNodeArray TraceResult::dump(
    return output_nodes;
}
std::vector<ValueRef> TracingTransformation::apply_transformation(
ValueRefList TracingTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_value = op.as<ApplyOp>()) {
        SmallVector<ValueRef> unwrapped_inputs;
@@ -180,11 +180,12 @@ std::vector<ValueRef> TracingTransformation::apply_transformation(
        }
        const_cast<OpDef&>(op_value->op()).set_scope(scopes_join);
        auto unwrapped_outputs = imperative::apply(op, unwrapped_inputs);
        std::vector<ValueRef> wrapped_outputs;
        ValueRefList wrapped_outputs(unwrapped_outputs.size());
        SmallVector<size_t> output_ids;
        for (auto&& output : unwrapped_outputs) {
        for (size_t i = 0; i < unwrapped_outputs.size(); ++i) {
            auto&& output = unwrapped_outputs[i];
            auto wrapped_output = record_var(output, false, VarKind::Internal);
            wrapped_outputs.push_back(wrapped_output);
            wrapped_outputs[i] = wrapped_output;
            output_ids.push_back(wrapped_output->id());
        }
        m_seq.push_back({op_value->op().shared_from_this(), input_ids, output_ids});
@@ -375,6 +376,11 @@ void CompiledTransformation::compile() {
        return accessor;
    };
    std::vector<VarAccessor> var_accessors(m_vars.size());
    auto exc_setter = std::bind(
            &CompiledTransformation::set_exception, this, std::placeholders::_1);
    for (auto&& accessor : var_accessors) {
        accessor.exc_setter = exc_setter;
    }
    for (auto&& item : m_seq) {
        bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo());
        VarNodeArray input_vars;
@@ -509,8 +515,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
    }
}
TracedValue::ref_t CompiledTransformation::trace_output(size_t id) {
    auto traced_value = TracedValue::make(id);
auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t {
    auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]);
    m_weak_values.push_back(traced_value);
    return traced_value;
}
@@ -520,64 +526,99 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
    return m_seq[m_pc++];
}
std::vector<ValueRef> CompiledTransformation::apply_transformation(
ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
    if (!m_shape) {
        trace_assert(m_accessor->shape_getter, "shape unreadable");
        m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter()));
    }
    return m_shape;
}
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const {
    if (!m_dtype) {
        m_dtype = DTypeValue::make(m_var->dtype);
    }
    return m_dtype;
}
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const {
    if (!m_comp_node) {
        m_comp_node = CompNodeValue::make(m_var->device);
    }
    return m_comp_node;
}
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& {
    return *m_accessor;
}
ValueRefList CompiledTransformation::apply_op(
        const ApplyOp& apply_op, Span<ValueRef> inputs) {
    auto& item = next_instruction();
    trace_assert(inputs.size() == item.inputs.size(), "input size mismatch");
    trace_assert(apply_op.op().is_same(*item.op), "operator mismatch");
    for (size_t i = 0; i < inputs.size(); ++i) {
        trace_input(item.inputs[i], inputs[i]);
    }
    ValueRefList outputs(item.outputs.size());
    for (size_t i = 0; i < item.outputs.size(); ++i) {
        outputs[i] = trace_output(item.outputs[i]);
    }
    return outputs;
}
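apply_op is the replay half of tracing: every op applied under a compiled trace must line up with the next recorded instruction, and the trace_assert calls reject any divergence in operator identity or arity. A toy standalone version of that check; the Instruction and Replayer names are invented for illustration:

#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

// Minimal illustration of replaying a recorded op sequence: each call must
// match the next instruction, mirroring the trace_assert checks above.
struct Instruction {
    std::string op;
    size_t nr_inputs;
};

struct Replayer {
    std::vector<Instruction> seq;
    size_t pc = 0;

    void apply(const std::string& op, size_t nr_inputs) {
        assert(pc < seq.size() && "trace exhausted");
        const auto& item = seq[pc++];
        assert(op == item.op && "operator mismatch");
        assert(nr_inputs == item.nr_inputs && "input size mismatch");
        std::printf("replayed %s\n", op.c_str());
    }
};

int main() {
    Replayer r{{{"add", 2}, {"relu", 1}}};
    r.apply("add", 2);
    r.apply("relu", 1);  // deviating from the recording would trip the asserts
}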
ValueRefList CompiledTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    if (auto* traced_value = inputs[0].as<TracedValue>()) {
        ValueRef output;
        auto& var_accessor = traced_value->accessor();
        switch (get_attr.attr()) {
            case GetAttr::Shape:
                output = traced_value->shape();
                break;
            case GetAttr::Data:
                trace_assert(var_accessor.data_getter, "data unreadable");
                output = DeviceValue::make(var_accessor.data_getter());
                break;
            case GetAttr::Value:
                trace_assert(var_accessor.value_getter, "value unreadable");
                output = HostValue::make(var_accessor.value_getter());
                break;
            case GetAttr::DType:
                output = traced_value->dtype();
                break;
            case GetAttr::Device:
                output = traced_value->comp_node();
                break;
            default:
                break;
        }
        return {output};
    } else {
        return imperative::apply(get_attr, inputs);
    }
}
ValueRefList CompiledTransformation::apply_create_tensor(
        const CreateTensor& create_tensor, Span<ValueRef> inputs) {
    if (create_tensor.kind() == CreateTensor::NoTrace) {
        return imperative::apply(create_tensor, inputs);
    }
    auto& item = next_instruction();
    trace_assert(item.op == nullptr, "operator mismatch");
    auto input_id = item.inputs[0];
    auto output_id = item.outputs[0];
    auto tensor = imperative::apply(create_tensor, inputs)[0];
    trace_input(input_id, tensor);
    return {trace_output(output_id)};
}
ValueRefList CompiledTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_value = op.as<ApplyOp>()) {
        auto& item = next_instruction();
        SmallVector<ValueRef> unwrapped_inputs;
        SmallVector<ValueRef> wrapped_inputs;
        trace_assert(inputs.size() == item.inputs.size(), "input size mismatch");
        trace_assert(op_value->op().is_same(*item.op), "operator mismatch");
        for (size_t i = 0; i < inputs.size(); ++i) {
            trace_input(item.inputs[i], inputs[i]);
        }
        std::vector<ValueRef> outputs;
        for (auto&& output_id : item.outputs) {
            outputs.push_back(trace_output(output_id));
        }
        return outputs;
        return apply_op(*op_value, inputs);
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->kind() == CreateTensor::NoTrace) {
            return imperative::apply(op, inputs);
        }
        auto& item = next_instruction();
        trace_assert(item.op == nullptr, "operator mismatch");
        auto input_id = item.inputs[0];
        auto output_id = item.outputs[0];
        auto tensor = imperative::apply(op, inputs)[0];
        trace_input(input_id, tensor);
        return {trace_output(output_id)};
        return apply_create_tensor(*create_tensor, inputs);
    } else if (auto* get_attr = op.as<GetAttr>()) {
        if (auto* traced_value = inputs[0].as<TracedValue>()) {
            ValueRef output;
            auto& var = m_vars[traced_value->id()];
            auto& var_accessor = m_var_accessors[traced_value->id()];
            switch (get_attr->attr()) {
                case GetAttr::Shape:
                    trace_assert(var_accessor.shape_getter, "shape unreadable");
                    output = ShapeValue::make(
                            ValueShape::from(var_accessor.shape_getter()));
                    break;
                case GetAttr::Data:
                    trace_assert(var_accessor.data_getter, "data unreadable");
                    output = DeviceValue::make(var_accessor.data_getter());
                    break;
                case GetAttr::Value:
                    trace_assert(var_accessor.value_getter, "value unreadable");
                    output = HostValue::make(var_accessor.value_getter());
                    break;
                case GetAttr::DType:
                    output = DTypeValue::make(var.dtype);
                    break;
                case GetAttr::Device:
                    output = CompNodeValue::make(var.device);
                default:
                    break;
            }
            return {output};
        } else {
            return imperative::apply(op, inputs);
        }
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* trace_mark_var = op.as<TraceMarkVar>()) {
        auto& item = next_instruction();
        trace_assert(item.op == nullptr, "operator mismatch");
@@ -8,50 +8,58 @@ namespace mgb {
namespace imperative {
namespace {
static thread_local size_t nr_watched_values = 0;
static thread_local uint64_t nr_values = 0;
static thread_local bool recording_values = false;
static thread_local std::vector<ValueWeakRef> recorded_values;
static /*thread_local*/ size_t nr_watched_values = 0;
static /*thread_local*/ uint64_t nr_values = 0;
static /*thread_local*/ bool recording_values = false;
static /*thread_local*/ std::vector<ValueWeakRef> recorded_values;
static WeakValueMap<uint64_t, ValueWeakRef> registered_values;
} // namespace
ValueRef::storage_t& ValueRef::storage() const {
    if (!m_storage) {
    if (mgb_likely(!m_storage->m_successor.m_storage)) {
        return m_storage;
    }
    if (auto& storage = m_storage->m_successor.m_storage) {
        while (storage->m_successor.m_storage) {
            storage = storage->m_successor.m_storage;
        }
        return storage;
    } else {
        return m_storage;
    while (m_storage->m_successor.m_storage) {
        m_storage = m_storage->m_successor.m_storage;
    }
    return m_storage;
}
const Value* ValueRef::as(size_t typecode) const {
    auto&& storage = this->storage();
    if (storage->m_typecode != typecode) {
        return nullptr;
    }
    return static_cast<Value*>(storage.get());
}
bool ValueRef::is(size_t typecode) const {
    return this->storage()->m_typecode == typecode;
}
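storage() now takes an mgb_likely fast path when there is no successor and otherwise splices the forwarding chain in place, so repeated lookups stay O(1): essentially union-find path compression. A standalone sketch of the idea, with hypothetical Node/resolve names:

#include <cstdio>
#include <memory>

// Standalone illustration of the successor-chain flattening in storage():
// follow forwarding pointers to the live node and splice the chain so the
// next lookup is O(1), as in union-find path compression.
struct Node {
    int payload;
    std::shared_ptr<Node> successor;  // set when the value is superseded
};

std::shared_ptr<Node>& resolve(std::shared_ptr<Node>& ref) {
    if (!ref->successor) {  // likely fast path
        return ref;
    }
    while (ref->successor) {
        ref = ref->successor;  // compress: ref now points at the live node
    }
    return ref;
}

int main() {
    auto c = std::make_shared<Node>(Node{3, nullptr});
    auto b = std::make_shared<Node>(Node{2, c});
    auto a = std::make_shared<Node>(Node{1, b});
    std::printf("%d\n", resolve(a)->payload);  // prints 3; a now aliases c
}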
TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
    return imperative::apply(GetAttr(GetAttr::Data), *this)[0].as_ref<DeviceValue>();
    return imperative::apply(GetAttr(GetAttr::Data), *this)[0].cast_ref<DeviceValue>();
}
TypedValueRef<HostValue> ValueRef::numpy() const {
    return imperative::apply(GetAttr(GetAttr::Value), *this)[0].as_ref<HostValue>();
    return imperative::apply(GetAttr(GetAttr::Value), *this)[0].cast_ref<HostValue>();
}
TypedValueRef<CompNodeValue> ValueRef::device() const {
    return imperative::apply(GetAttr(GetAttr::Device), *this)[0]
            .as_ref<CompNodeValue>();
            .cast_ref<CompNodeValue>();
}
TypedValueRef<ShapeValue> ValueRef::shape() const {
    return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].as_ref<ShapeValue>();
    return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].cast_ref<ShapeValue>();
}
TypedValueRef<DTypeValue> ValueRef::dtype() const {
    return imperative::apply(GetAttr(GetAttr::DType), *this)[0].as_ref<DTypeValue>();
    return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>();
}
TypedValueRef<StringValue> ValueRef::name() const {
    return imperative::apply(GetName(), *this)[0].as_ref<StringValue>();
    return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>();
}
bool ValueRef::is_scalar() const {
@@ -75,13 +83,15 @@ void ValueRef::unwatch() const {
}
ValueRef ValueRef::unwrap() const {
    ValueRef value = *this;
    auto& context = Transformation::get_context();
    for (size_t i = 0; i < context.next_transformation; ++i) {
        value = context.transformations[i]->unwrap(value);
    if (mgb_unlikely(context.next_transformation)) {
        ValueRef value = *this;
        for (size_t i = 0; i < context.next_transformation; ++i) {
            value = context.transformations[i]->unwrap(value);
        }
        return value;
    }
    mgb_assert(value);
    return value;
    return *this;
}
std::string ValueRef::to_string() const {
@@ -101,13 +111,11 @@ std::string ValueRef::raw_type() const {
    return types[m_storage->m_typecode].name();
}
uint64_t ValueRef::id() const {
    return m_storage ? m_storage->m_id : std::numeric_limits<uint64_t>::max();
}
bool ValueRef::watching() const {
    auto storage = this->storage();
    return storage && storage->m_watching;
    if (!m_storage) {
        return false;
    }
    return this->storage()->m_watching;
}
ValueRef ValueRef::make(ValueRef::storage_t storage) {
@@ -186,5 +194,96 @@ void Value::try_rethrow() {
    }
}
inline void ValueRefList::init(size_t nr_elems) {
    m_size = nr_elems;
    if (m_size > 0) {
        if (m_size == 1) {
            m_data = inline_storage();
        } else {
            auto& context = Transformation::get_context();
            m_data = context.allocator.allocate(m_size);
        }
        for (size_t i = 0; i < m_size; ++i) {
            new (m_data + i) ValueRef();
        }
    } else {
        m_data = nullptr;
    }
}
ValueRefList::ValueRefList(size_t nr_elems) {
    init(nr_elems);
}
ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
        : ValueRefList(values.begin(), values.end()) {}
ValueRefList::ValueRefList(const ValueRefList& rhs)
        : ValueRefList(rhs.cbegin(), rhs.cend()) {}
ValueRefList::ValueRefList(ValueRefList&& rhs) : ValueRefList() {
    m_size = rhs.m_size;
    if (rhs.m_data == rhs.inline_storage()) {
        m_data = inline_storage();
        new (m_data) ValueRef();
        m_data[0] = std::move(rhs.m_data[0]);
    } else {
        m_data = rhs.m_data;
        rhs.m_data = nullptr;
        rhs.m_size = 0;
    }
}
ValueRefList& ValueRefList::operator=(const ValueRefList& rhs) {
    if (this == &rhs) {
        return *this;
    }
    clear();
    init(rhs.m_size);
    for (size_t i = 0; i < m_size; ++i) {
        m_data[i] = rhs.m_data[i];
    }
    return *this;
}
ValueRefList& ValueRefList::operator=(ValueRefList&& rhs) {
    if (this == &rhs) {
        return *this;
    }
    clear();
    if (rhs.m_data == rhs.inline_storage()) {
        m_data = inline_storage();
        new (m_data) ValueRef();
        m_data[0] = rhs.m_data[0];
        m_size = 1;
        rhs.clear();
    } else {
        m_data = rhs.m_data;
        m_size = rhs.m_size;
        rhs.m_data = nullptr;
        rhs.m_size = 0;
    }
    return *this;
}
ValueRefList::~ValueRefList() {
    clear();
}
void ValueRefList::clear() {
    for (size_t i = 0; i < m_size; ++i) {
        m_data[i].~ValueRef();
    }
    if (m_data) {
        if (m_size != 1) {
            Transformation::get_context().allocator.deallocate(m_data, m_size);
        } else {
            mgb_assert(m_data == inline_storage());
        }
    }
    m_data = nullptr;
    m_size = 0;
}
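ValueRefList keeps a single element in inline storage and only asks the context's ForwardAllocator for memory when a list has two or more entries, which matches the common case of single-output ops. A condensed standalone analogue of that small-buffer optimization, with plain new[]/delete[] standing in for the arena allocator:

#include <cstddef>
#include <cstdio>
#include <cstring>

// Condensed small-buffer-optimized list mirroring ValueRefList's layout:
// one element lives in inline storage, larger lists hit the allocator.
struct IntList {
    int* m_data = nullptr;
    size_t m_size = 0;
    alignas(int) unsigned char m_storage[sizeof(int)];

    int* inline_storage() { return reinterpret_cast<int*>(m_storage); }

    explicit IntList(size_t n) : m_size(n) {
        if (n == 0) return;
        m_data = (n == 1) ? inline_storage() : new int[n];
        std::memset(m_data, 0, n * sizeof(int));
    }
    ~IntList() {
        if (m_data && m_size != 1) delete[] m_data;  // inline case frees nothing
    }
    IntList(const IntList&) = delete;
    IntList& operator=(const IntList&) = delete;
    int& operator[](size_t i) { return m_data[i]; }
};

int main() {
    IntList one(1), many(4);
    one[0] = 7;
    many[3] = 9;
    std::printf("%d %d inline=%d\n", one[0], many[3],
                one.m_data == one.inline_storage());  // 7 9 inline=1
}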
} // namespace imperative
} // namespace mgb
@@ -24,8 +24,6 @@ namespace imperative {
class GradKey;
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
/**
 * \brief apply an OpDef to values
 *
@@ -37,7 +35,7 @@ private:
public:
    ApplyOp(const OpDef& op) : m_op(op) {}
    const OpDef& op() { return m_op; }
    const OpDef& op() const { return m_op; }
    std::string to_string() const override;
};
@@ -106,7 +104,7 @@ public:
     * \param inputs contains host_storage and device_storage
     * \return Args unpacked args
     */
    Args parse(Span<ValueRef> inputs);
    Args parse(Span<ValueRef> inputs) const;
    Kind kind() const { return m_kind; }
    CompNode device() const { return m_device; }
@@ -129,11 +127,11 @@ private:
public:
    DTRCommand(Kind kind) : m_kind(kind) {}
    Kind kind() { return m_kind; }
    Kind kind() const { return m_kind; }
    std::string to_string() const override;
    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { return {}; }
    ValueRefList fallback(Span<ValueRef> inputs) const override { return {}; }
};
// deprecated
@@ -141,9 +139,7 @@ class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> {
public:
    std::string to_string() const override;
    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
        return {ValueRef()};
    }
    ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
/**
@@ -161,7 +157,7 @@ public:
    std::string to_string() const override;
    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
    ValueRefList fallback(Span<ValueRef> inputs) const override {
        return {inputs.as_array<1>()[0]};
    }
};
@@ -23,7 +23,7 @@ namespace imperative {
class GradKey;
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;
class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
public:
@@ -97,6 +97,10 @@ public:
    ValueShape shape() const { return m_shape; }
    CompNode device() const { return m_storage.comp_node(); }
    HostTensorStorage storage() const { return m_storage; }
    DTypeScalar item() const {
        mgb_assert(m_shape.is_scalar());
        return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());
    }
    HostTensorND as_nd(bool allow_scalar = false) const;
};
@@ -36,11 +36,11 @@ namespace imperative {
 *
 * \param op
 * \param inputs
 * \return std::vector<ValueRef>
 * \return ValueRefList
 */
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs);
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs);
ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
ValueRefList apply(const OpDef& def, Span<ValueRef> inputs);
ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs);
template <typename... TArgs>
constexpr bool is_all_value_ref_v =
@@ -49,7 +49,7 @@ constexpr bool is_all_value_ref_v =
template <typename T, typename... TArgs>
static auto apply(T&& op, TArgs&&... args)
        -> std::enable_if_t<is_all_value_ref_v<TArgs...>, std::vector<ValueRef>> {
        -> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> {
    ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...};
    return imperative::apply(
            std::forward<T&&>(op),
@@ -63,7 +63,7 @@ static auto apply(T&& op, TContainer&& container) -> std::enable_if_t<
                ValueRef> &&
                std::is_same_v<decltype(container.size()), size_t> &&
                !std::is_same_v<std::decay_t<TContainer>, Span<ValueRef>>,
        std::vector<ValueRef>> {
        ValueRefList> {
    return imperative::apply(
            std::forward<T&&>(op), Span<ValueRef>(container.data(), container.size()));
}
@@ -25,6 +25,8 @@
namespace mgb {
namespace imperative {
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;
/**
 * \brief base class for all operators
 *
@@ -49,25 +51,24 @@ public:
    Kind kind() const { return m_kind; }
    template <typename U>
    U* as() const {
    const U* as() const {
        if (m_typecode != U::TYPE_CODE) {
            return nullptr;
        }
        return static_cast<U*>(const_cast<Operator*>(this));
        return static_cast<const U*>(this);
    }
    template <typename U>
    bool is() const {
        return as<U>() != nullptr;
        return m_typecode == U::TYPE_CODE;
    }
    template <Kind kKind>
    bool is() const {
        return kind() == kKind;
    }
    template <typename U>
    U& cast() const {
        U* ptr = as<U>();
        mgb_assert(ptr);
        return *ptr;
    const U& cast() const {
        mgb_assert(m_typecode == U::TYPE_CODE);
        return static_cast<const U&>(*this);
    }
    virtual std::string to_string() const = 0;
@@ -77,9 +78,9 @@ public:
     * implementation.
     *
     * \param inputs
     * \return std::vector<ValueRef>
     * \return ValueRefList
     */
    virtual std::vector<ValueRef> fallback(Span<ValueRef> inputs) const;
    virtual ValueRefList fallback(Span<ValueRef> inputs) const;
    std::type_index type() const { return registered_types()[m_typecode]; }
@@ -123,7 +123,6 @@ public:
    template <typename T, typename... TArgs>
    static uint64_t record(TArgs&&... args) {
        auto& profiler = get_instance();
        // auto& mem_pool = get_mem_pool<T>();
        if constexpr (sm_debug) {
            Status expected = Running;
            mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording));
@@ -18,6 +18,7 @@
#include "megbrain/common.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/utils/allocator.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
@@ -25,6 +26,7 @@ namespace mgb {
namespace imperative {
class ValueRef;
class ValueRefList;
class Operator;
class Transformation;
@@ -43,6 +45,7 @@ struct TransformationContext {
    // TODO: deprecate TransformationGuard, let next_transformation == frames.size()
    size_t next_transformation = 0;
    std::vector<TransformationFrame> frames;
    ForwardAllocator<ValueRef> allocator;
};
/**
@@ -86,9 +89,9 @@ public:
     *
     * \param op
     * \param inputs
     * \return std::vector<ValueRef>
     * \return ValueRefList
     */
    virtual std::vector<ValueRef> apply_transformation(
    virtual ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) = 0;
    virtual ValueRef unwrap(ValueRef value) = 0;
@@ -187,11 +190,12 @@ public:
        std::swap(context.transformations, current_context.transformations);
        std::swap(context.scopes, current_context.scopes);
        std::swap(context.next_transformation, current_context.next_transformation);
        std::swap(context.allocator, current_context.allocator);
    }
    static TransformationContext& get_context();
    friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
    friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
    friend class ValueRef;
};
@@ -23,16 +23,38 @@ public:
    using Handle = interpreter::Interpreter::Handle;
    using Channel = interpreter::Interpreter::Channel;
    class RAIIHandle : public NonCopyableObj {
    private:
        Handle m_handle = nullptr;
        Channel* m_channel = nullptr;
    public:
        RAIIHandle(Handle handle, Channel* channel)
                : m_handle(handle), m_channel(channel) {}
        ~RAIIHandle() { m_channel->del(m_handle); }
        Handle handle() const { return m_handle; }
        Channel* channel() const { return m_channel; }
    };
private: | private: | ||||
std::shared_ptr<Handle> m_handle = nullptr; | |||||
LocalPtr<RAIIHandle> m_handle; | |||||
std::string m_name; | std::string m_name; | ||||
mutable DTypeValue::ref_t m_dtype; | |||||
mutable CompNodeValue::ref_t m_comp_node; | |||||
mutable ShapeValue::ref_t m_shape; | |||||
public: | public: | ||||
InterpreterInfo() = default; | InterpreterInfo() = default; | ||||
InterpreterInfo(std::shared_ptr<Handle> handle, std::string name = {}) | |||||
InterpreterInfo(LocalPtr<RAIIHandle> handle, std::string name = {}) | |||||
: m_handle(handle), m_name(name) {} | : m_handle(handle), m_name(name) {} | ||||
std::shared_ptr<Handle> handle() const { return m_handle; } | |||||
const LocalPtr<RAIIHandle>& handle() const { return m_handle; } | |||||
DTypeValue::ref_t dtype() const; | |||||
CompNodeValue::ref_t comp_node() const; | |||||
ShapeValue::ref_t shape() const; | |||||
std::string name() const { return m_name; } | std::string name() const { return m_name; } | ||||
}; | }; | ||||
@@ -60,6 +82,7 @@ class InterpreterTransformation final : public Transformation {
 public:
     using Interpreter = interpreter::Interpreter;
     using Handle = Interpreter::Handle;
+    using SharedHandle = LocalPtr<InterpreterInfo::RAIIHandle>;
     using Channel = Interpreter::Channel;

 private:
@@ -71,7 +94,14 @@ public:
     Channel* channel() { return m_channel.get(); }

-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs);
+
+    ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
+
+    ValueRefList apply_create_tensor(
+            const CreateTensor& create_tensor, Span<ValueRef> inputs);
+
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override;

     ValueRef unwrap(ValueRef value) override {
@@ -81,14 +111,8 @@ public:
     std::string name() const override { return "InterpreterTransformation"; }

-    std::shared_ptr<Handle> share_handle(Handle handle) {
-        return std::shared_ptr<Handle>(
-                new Handle(handle), [channel = m_channel.get()](Handle* ptr) {
-                    if (ptr) {
-                        channel->del(*ptr);
-                        delete ptr;
-                    }
-                });
+    SharedHandle share_handle(Handle handle) {
+        return SharedHandle::make(handle, m_channel.get());
     }
 };
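The replaced share_handle() shows the motivation for RAIIHandle: a shared_ptr with a lambda deleter allocates both the boxed Handle and a control block, while one pooled RAIIHandle object carries the handle and the channel that must delete it. A minimal sketch under toy types (Handle as an int, plain make_shared instead of MegEngine's pooled LocalPtr):

    #include <cstdio>
    #include <memory>

    using Handle = int;
    struct Channel {
        void del(Handle h) { std::printf("del(%d)\n", h); }
    };

    // RAII wrapper in the spirit of InterpreterInfo::RAIIHandle: destruction
    // always routes through the owning channel, so no custom deleter is needed.
    class RAIIHandle {
        Handle m_handle;
        Channel* m_channel;
    public:
        RAIIHandle(Handle handle, Channel* channel)
                : m_handle(handle), m_channel(channel) {}
        RAIIHandle(const RAIIHandle&) = delete;
        RAIIHandle& operator=(const RAIIHandle&) = delete;
        ~RAIIHandle() { m_channel->del(m_handle); }
        Handle handle() const { return m_handle; }
    };

    int main() {
        Channel chan;
        auto h = std::make_shared<RAIIHandle>(42, &chan);
        auto alias = h;  // refcount 2; del(42) runs once, at the last release
    }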
@@ -34,9 +34,7 @@ struct BackwardGraphWithClosure {
             std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
             std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs);

-    void operator()(
-            std::vector<ValueRef> grads,
-            std::function<void(size_t, ValueRef)> receiver);
+    void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver);

     bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }

@@ -50,12 +48,11 @@ struct BackwardGraphWithClosure {
 struct CustomBackward;

-using GradRuleFn =
-        std::function<std::vector<ValueRef>(Span<ValueRef> inputs, CustomBackward&)>;
+using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>;

 struct CustomBackward {
-    using BackwardFn = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
-    using BackwardRule = std::function<std::optional<std::vector<ValueRef>>(
+    using BackwardFn = std::function<ValueRefList(Span<ValueRef>)>;
+    using BackwardRule = std::function<std::optional<ValueRefList>(
             const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>;
     BackwardFn m_backward;
     SmallVector<bool, 8> m_input_has_grad;
@@ -65,9 +62,7 @@ struct CustomBackward {
     SmallVector<OutputAttr> m_output_attrs;

 public:
-    void operator()(
-            std::vector<ValueRef> grads,
-            std::function<void(size_t, ValueRef)> receiver);
+    void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver);

     bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
     bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
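These closures deliver gradients through a receiver(index, grad) callback instead of returning a dense vector, so inputs that need no grad are simply skipped. A toy sketch of that shape, with scalars standing in for tensors and all names assumed (sizes of the two member vectors are taken to match):

    #include <functional>
    #include <vector>

    // Toy gradient closure: m_backward maps output-grads to input-grads, and
    // operator() forwards only the grads that inputs actually asked for,
    // mirroring the receiver-callback style of CustomBackward.
    struct ToyBackward {
        std::function<std::vector<double>(const std::vector<double>&)> m_backward;
        std::vector<bool> m_input_has_grad;

        void operator()(
                std::vector<double> grads,
                std::function<void(size_t, double)> receiver) {
            auto input_grads = m_backward(grads);
            for (size_t i = 0; i < input_grads.size(); ++i) {
                if (m_input_has_grad[i]) {
                    receiver(i, input_grads[i]);  // sparse delivery, no zero fill
                }
            }
        }
    };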
@@ -188,7 +183,7 @@ public:
     std::string to_string() const override;

-    bool has_key(std::shared_ptr<GradKey> key) const { return m_key == key; }
+    bool has_key(const std::shared_ptr<GradKey>& key) const { return m_key == key; }

     const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const {
         mgb_assert(m_key == key);

@@ -287,7 +282,7 @@ public:
         return false;
     }

-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override;

     ValueRef unwrap(ValueRef value) override {

@@ -314,7 +309,7 @@ private:
 public:
     std::string to_string() const override { return "DetachValue"; }

-    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
+    ValueRefList fallback(Span<ValueRef> inputs) const override {
         return {inputs.as_array<1>()[0]};
     }
 };

@@ -325,7 +320,7 @@ private:
 public:
     AttachGrad(std::shared_ptr<GradKey> key) : m_key(key) {}

-    std::shared_ptr<GradKey> key() { return m_key; }
+    std::shared_ptr<GradKey> key() const { return m_key; }

     std::string to_string() const override {
         return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str());

@@ -339,7 +334,7 @@ private:
 public:
     GradBackward(std::shared_ptr<GradKey> key) : m_key(key) {}

-    std::shared_ptr<GradKey> key() { return m_key; }
+    std::shared_ptr<GradKey> key() const { return m_key; }

     std::string to_string() const override {
         return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str());

@@ -352,13 +347,13 @@ private:
 public:
     IsAttachedTo(std::shared_ptr<GradKey> key) : m_key(key) {}

-    std::shared_ptr<GradKey> key() { return m_key; }
+    std::shared_ptr<GradKey> key() const { return m_key; }

     std::string to_string() const override {
         return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str());
     }

-    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
+    ValueRefList fallback(Span<ValueRef> inputs) const override {
         return {BoolValue::make(false)};
     }
 };

@@ -373,9 +368,9 @@ public:
     SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs)
             : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}

-    GenericFunction grad_fn() { return m_grad_fn; }
+    GenericFunction grad_fn() const { return m_grad_fn; }
-    size_t nr_inputs() { return m_nr_inputs; }
+    size_t nr_inputs() const { return m_nr_inputs; }

     std::string to_string() const override {
         return ssprintf("SetGradValue{key=%s}", m_key->name().c_str());

@@ -388,9 +383,7 @@ public:
     std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); }

-    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
-        return {ValueRef()};
-    }
+    ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
 };
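Several of these grad operators override fallback(): the result an operator produces when no transformation on the stack claims it, e.g. IsAttachedTo degrading to false. A condensed sketch of that dispatch shape, with toy operator and result types (all names invented):

    #include <functional>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Toy dispatch: each operator may supply a default result so that, with an
    // empty transformation stack, queries degrade gracefully instead of failing.
    struct ToyOp {
        std::string name;
        std::function<std::vector<bool>()> fallback;  // empty == no fallback
    };

    std::vector<bool> dispatch(const ToyOp& op /*, transformation stack ... */) {
        // ... assume no transformation handled `op` ...
        if (op.fallback)
            return op.fallback();
        throw std::runtime_error("no fallback for " + op.name);
    }

    // Usage: ToyOp is_attached{"IsAttachedTo", [] { return std::vector<bool>{false}; }};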
 class GetBackwardColsure
@@ -401,7 +394,7 @@ private:
 public:
     GetBackwardColsure(std::shared_ptr<GradKey> key) : m_key(key) {}

-    std::shared_ptr<GradKey> key() { return m_key; }
+    std::shared_ptr<GradKey> key() const { return m_key; }

     std::string to_string() const override {
         return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str());
@@ -81,7 +81,7 @@ public:
     ComputingGraph::Options& options() { return m_graph->options(); }

-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override;

     ValueRef unwrap(ValueRef value) override {

@@ -11,6 +11,7 @@
 #pragma once

+#include "megbrain/imperative/basic_operators.h"
 #include "megbrain/imperative/dispatch.h"
 #include "megbrain/imperative/ops/autogen.h"

@@ -45,8 +46,10 @@ public:
  */
 class ScalarTransformation final : public Transformation {
 private:
+    ShapeValue::ref_t m_empty_shape;  // []
 public:
-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
+
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override;

     ValueRef unwrap(ValueRef value) override {

@@ -50,7 +50,7 @@ private:
 public:
     SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {}

-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override {
         if (auto* apply_op = op.as<ApplyOp>()) {
             SmallVector<VarNode*> input_nodes;
@@ -58,9 +58,9 @@ public:
                 input_nodes.push_back(input.cast<SymbolValue>().node());
             }
             auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes);
-            std::vector<ValueRef> outputs;
-            for (auto&& output_node : output_nodes) {
-                outputs.push_back(SymbolValue::make(output_node));
+            ValueRefList outputs(output_nodes.size());
+            for (size_t i = 0; i < output_nodes.size(); ++i) {
+                outputs[i] = SymbolValue::make(output_nodes[i]);
             }
             return outputs;
         } else if (auto* create_tensor = op.as<CreateTensor>()) {
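The rewritten loop above reflects that ValueRefList has no push_back: results are sized once up front and filled by index, which also lets the backing arena reserve exactly one slab per dispatch. The same pattern with a plain vector as a stand-in (function and names invented for illustration):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Constructed at final size, then filled by index, as in
    // `ValueRefList outputs(output_nodes.size())`.
    std::vector<std::string> wrap_outputs(const std::vector<int>& output_nodes) {
        std::vector<std::string> outputs(output_nodes.size());  // one allocation
        for (size_t i = 0; i < output_nodes.size(); ++i) {
            outputs[i] = "var#" + std::to_string(output_nodes[i]);
        }
        return outputs;  // no incremental growth, no reallocation
    }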
@@ -0,0 +1,36 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/grad.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/imperative/basic_operators.h"
+#include "megbrain/imperative/operator.h"
+#include "megbrain/imperative/transformation.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb::imperative {
+
+struct TangentInfo {
+    ValueRef value;
+    ValueRef tangent;
+};
+
+class TangentTransformation final : public Transformation {
+public:
+    ValueRefList apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override;
+    ValueRef unwrap(ValueRef value) override { mgb_assert(false); }
+    std::string name() const override { return "Tangent"; }
+};
+
+} // namespace mgb::imperative
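TangentInfo pairs each value with its tangent, which is the shape of forward-mode autodiff. The new file only declares the interface; the dual-number sketch below shows the arithmetic such a transformation would propagate per op (toy scalar types, not MegEngine's):

    #include <cmath>
    #include <cstdio>

    // Dual number: value plus tangent, the scalar analogue of TangentInfo.
    // Each primitive propagates both components in a single forward pass.
    struct Dual {
        double value;
        double tangent;
    };

    Dual mul(Dual a, Dual b) {
        return {a.value * b.value, a.tangent * b.value + a.value * b.tangent};
    }
    Dual sin(Dual a) {
        return {std::sin(a.value), std::cos(a.value) * a.tangent};
    }

    int main() {
        Dual x{1.5, 1.0};         // seed dx/dx = 1
        Dual y = sin(mul(x, x));  // y = sin(x^2), dy/dx = cos(x^2) * 2x
        std::printf("%f %f\n", y.value, y.tangent);
    }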
@@ -126,25 +126,6 @@ public:
     void on_unwatch() override { value().unwatch(); }
 };

-class TracedInfo {
-private:
-    size_t m_id = 0;
-
-public:
-    TracedInfo() = default;
-    TracedInfo(size_t id) : m_id(id) {}
-
-    size_t id() const { return m_id; }
-};
-
-class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
-public:
-    using MixinValueImpl::MixinValueImpl;
-
-    std::string to_string() const override {
-        return ssprintf("TracedValue{\"id\"=%zu}", id());
-    }
-};
-
 /**
  * \brief trace operation sequence to TraceResult
  *
@@ -202,7 +183,7 @@ public:
         return value;
     }

-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override;

     ValueRef unwrap(ValueRef value) override {
@@ -248,6 +229,40 @@ public:
     std::function<DeviceTensorND()> data_getter;
     std::function<HostTensorND()> value_getter;
     std::function<void(DeviceTensorND)> data_setter;
+    std::function<void(std::exception_ptr)> exc_setter;
+};
+
+class TracedInfo {
+private:
+    size_t m_id = 0;
+    VarInfo* m_var = nullptr;
+    VarAccessor* m_accessor = nullptr;
+    mutable ShapeValue::ref_t m_shape;
+    mutable DTypeValue::ref_t m_dtype;
+    mutable CompNodeValue::ref_t m_comp_node;
+
+public:
+    TracedInfo() = default;
+    TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor)
+            : m_id(id), m_var(var), m_accessor(accessor) {}
+
+    size_t id() const { return m_id; }
+    ShapeValue::ref_t shape() const;
+    DTypeValue::ref_t dtype() const;
+    CompNodeValue::ref_t comp_node() const;
+    const VarAccessor& accessor() const;
+
+    void set_exception(std::exception_ptr exc) const {
+        m_accessor->exc_setter(exc);
+    }
+};
+
+class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override {
+        return ssprintf("TracedValue{\"id\"=%zu}", id());
+    }
 };

 private:
@@ -319,7 +334,14 @@ public:
     TraceResult::SeqItem& next_instruction();

-    std::vector<ValueRef> apply_transformation(
+    ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs);
+
+    ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
+
+    ValueRefList apply_create_tensor(
+            const CreateTensor& create_tensor, Span<ValueRef> inputs);
+
+    ValueRefList apply_transformation(
             const Operator& op, Span<ValueRef> inputs) override;

     void on_unregister() noexcept override;
@@ -36,12 +36,12 @@ private:
 public:
     Allocator(pool_type* pool) : m_pool(pool) {}

-    T* allocate(size_type n) {
+    pointer allocate(size_type n) {
         mgb_assert(n == 1);
         return m_pool->alloc(sizeof(T));
     }

-    void deallocate(pointer* p, size_type n) {
+    void deallocate(pointer p, size_type n) {
         mgb_assert(n == 1);
         m_pool->free(p);
     }
@@ -68,4 +68,114 @@ public:
     bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; }
 };

-} // namespace mgb::imperative
+template <typename T>
+class ForwardAllocator {
+public:
+    using value_type = T;
+    using size_type = std::size_t;
+    using pointer = T*;
+
+    static constexpr size_t alignment = alignof(T);
+    // stride between elements: sizeof(T) rounded up to the alignment
+    static constexpr size_t element_offset =
+            sizeof(T) +
+            ((sizeof(T) % alignment) ? (alignment - sizeof(T) % alignment) : 0);
+
+private:
+    struct Block {
+        std::unique_ptr<std::byte[]> data;
+        size_t size = 0;
+        size_t capacity = 0;
+
+        T* allocate(size_type n) {
+            static_assert(element_offset >= std::max(alignment, sizeof(T)));
+            size_t begin = size;
+            size_t end = begin + element_offset * n;
+            if (end > capacity) {
+                return nullptr;
+            }
+            size = end;
+            return reinterpret_cast<T*>(data.get() + begin);
+        }
+
+        void reset() { size = 0; }
+    };
+
+    std::vector<Block> m_used;
+    std::optional<Block> m_current;
+    size_t block_size = 16 * 1024 * 1024;
+    size_t nr_allocated = 0;
+
+private:
+    Block allocate_block() {
+        // each new block doubles in size, so a huge request terminates the
+        // retry loop in allocate()
+        block_size *= 2;
+        return Block{std::make_unique<std::byte[]>(block_size), 0, block_size};
+    }
+
+public:
+    pointer allocate(size_type n) {
+        if (!m_current) {
+            m_current.emplace(allocate_block());
+        }
+        pointer ptr = m_current->allocate(n);
+        while (ptr == nullptr) {
+            m_used.push_back(allocate_block());
+            std::swap(m_used.back(), *m_current);
+            ptr = m_current->allocate(n);
+        }
+        nr_allocated++;
+        return ptr;
+    }
+
+    void deallocate(pointer p, size_type n) {
+        mgb_assert(nr_allocated > 0);
+        nr_allocated--;
+    }
+
+    void clear() {
+        if (mgb_likely(m_used.empty())) {
+            // fast path: reuse the current block
+            if (m_current) {
+                m_current->reset();
+            }
+        } else {
+            // trim: drop every used block and start from a fresh one
+            *m_current = allocate_block();
+            m_used.clear();
+        }
+        mgb_assert(nr_allocated == 0);
+    }
+
+    bool operator==(const ForwardAllocator& rhs) const { return &rhs == this; }
+    bool operator!=(const ForwardAllocator& rhs) const { return &rhs != this; }
+};
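ForwardAllocator is a bump allocator: deallocate() only decrements a counter, and memory is reclaimed wholesale by clear() once every outstanding object is gone. A usage sketch assuming the ForwardAllocator defined above (placement-new is required because allocate() hands back raw, unconstructed storage):

    #include <new>

    void forward_allocator_demo() {
        ForwardAllocator<int> arena;

        int* a = arena.allocate(1);  // bump-pointer: O(1), no free list
        int* b = arena.allocate(1);
        new (a) int(1);              // storage is raw; construct in place
        new (b) int(2);

        arena.deallocate(a, 1);      // bookkeeping only, memory not reused yet
        arena.deallocate(b, 1);

        arena.clear();               // now the whole block is recycled at once
    }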
+template <typename T, template <typename> typename TAllocator>
+class ProxyAllocator {
+public:
+    using value_type = T;
+    using size_type = typename TAllocator<T>::size_type;
+    using pointer = typename TAllocator<T>::pointer;
+
+private:
+    TAllocator<T>* m_impl;
+
+public:
+    T* allocate(size_type n) { return m_impl->allocate(n); }
+    void deallocate(pointer p, size_type n) { return m_impl->deallocate(p, n); }
+
+    bool operator==(const ProxyAllocator<T, TAllocator>& rhs) const {
+        if (m_impl == rhs.m_impl) {
+            return true;
+        } else if (bool(m_impl) ^ bool(rhs.m_impl)) {
+            return false;
+        } else {
+            return *m_impl == *rhs.m_impl;
+        }
+    }
+
+    bool operator!=(const ProxyAllocator<T, TAllocator>& rhs) const {
+        return !((*this) == rhs);
+    }
+};
+
+} // namespace mgb::imperative
@@ -16,6 +16,8 @@
 #include "megbrain/imperative/utils/mempool.h"
 #include "megbrain/utils/metahelper.h"

+#define MGB_FAT_LOCAL_PTR 0
+
 namespace mgb::imperative {

 template <typename T>
@@ -52,6 +54,8 @@ private:
         }
     }

+    size_t ref_count() const { return m_ref_count; }
+
     template <typename U>
     friend class LocalPtr;

@@ -88,14 +92,24 @@ public:
     using storage_t = LocalPtrStorage<T>;
     using pool_t = MemPool<storage_t>;
     using weak_type = LocalWeakPtr<T>;
+    using pointer_t = T*;

 private:
     storage_t* m_storage = nullptr;
+#if MGB_FAT_LOCAL_PTR
+    pointer_t m_pointer = nullptr;
+#endif
+    // (m_storage == nullptr) == (m_pointer == nullptr)

     void emplace(storage_t* ptr) {
         if (ptr) {
             ptr->inc_ref();
             m_storage = ptr;
+#if MGB_FAT_LOCAL_PTR
+            m_pointer = ptr->m_pointer;
+#endif
         }
     }

@@ -103,8 +117,22 @@ private:
 public:
     LocalPtr() = default;

-    LocalPtr(const LocalPtr& rhs) { (*this) = rhs; }
-    LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); }
+    LocalPtr(const LocalPtr& rhs) {
+        auto storage = rhs.m_storage;
+        if (storage) {
+            storage->inc_ref();
+        }
+        m_storage = storage;
+#if MGB_FAT_LOCAL_PTR
+        m_pointer = rhs.m_pointer;
+#endif
+    }
+
+    LocalPtr(LocalPtr&& rhs) {
+        std::swap(m_storage, rhs.m_storage);
+#if MGB_FAT_LOCAL_PTR
+        std::swap(m_pointer, rhs.m_pointer);
+#endif
+    }

     LocalPtr& operator=(const LocalPtr& rhs) {
         if (this == &rhs) {
             return *this;
@@ -115,9 +143,11 @@ public:
         }
         if (m_storage) {
             m_storage->dec_ref();
+            // rhs.m_storage may be invalid here
         }
         m_storage = storage;
+#if MGB_FAT_LOCAL_PTR
+        m_pointer = rhs.m_pointer;
+#endif
         return *this;
     }

     LocalPtr& operator=(LocalPtr&& rhs) {
@@ -125,6 +155,9 @@ public:
             return *this;
         }
         std::swap(m_storage, rhs.m_storage);
+#if MGB_FAT_LOCAL_PTR
+        std::swap(m_pointer, rhs.m_pointer);
+#endif
         rhs.reset();
         return *this;
     }

@@ -186,10 +219,11 @@ public:
     T& operator*() const { return *get(); }

     T* get() const {
-        if ((!m_storage) || !m_storage->m_pointer) {
-            return nullptr;
-        }
-        return m_storage->m_pointer;
+#if MGB_FAT_LOCAL_PTR
+        return m_pointer;
+#else
+        return m_storage ? m_storage->m_pointer : nullptr;
+#endif
     }

     T* operator->() const { return get(); }

@@ -202,6 +236,9 @@ public:
         if (m_storage) {
            m_storage->dec_ref();
            m_storage = nullptr;
+#if MGB_FAT_LOCAL_PTR
+            m_pointer = nullptr;
+#endif
         }
     }
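MGB_FAT_LOCAL_PTR is compiled out here (defined to 0), but the toggle is worth spelling out: a "fat" handle caches the payload pointer inline so get() avoids a dependent load through the control block, at the cost of a larger handle to copy and keep in sync. A toy sketch of the trade-off (simplified refcounting, invented names):

    #include <cstddef>

    // Toy "fat" handle: caches obj alongside the control block so get() is a
    // plain member read instead of an indirection through the control block.
    template <typename T>
    struct FatPtr {
        struct Control {
            T* obj;
            size_t refs;
        };
        Control* ctrl = nullptr;
        T* cached = nullptr;  // must mirror ctrl->obj; doubles the handle size

        T* get() const { return cached; }  // no indirection
        // A "thin" handle would instead do: return ctrl ? ctrl->obj : nullptr;
    };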
@@ -49,8 +49,8 @@ public:
             instance = std::make_unique<MemPool<T>>();
             sm_instance = instance.get();
         }
+        mgb_assert(sm_instance);
     }
+    return *sm_instance;
 }
 };

@@ -62,9 +62,9 @@ std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>>
         MemPoolUtils<T>::sm_instances;

 template <typename T>
-thread_local MemPool<T>* MemPoolUtils<T>::tm_instance;
+thread_local MemPool<T>* MemPoolUtils<T>::tm_instance = nullptr;

 template <typename T>
-MemPool<T>* MemPoolUtils<T>::sm_instance;
+MemPool<T>* MemPoolUtils<T>::sm_instance = nullptr;

 } // namespace mgb::imperative
@@ -95,6 +95,8 @@ struct ValueShape {
         }
         return true;
     }
+
+    bool operator!=(const ValueShape& rhs) const { return !operator==(rhs); }
 };

 static_assert(sizeof(size_t) >= sizeof(int));

@@ -47,6 +47,17 @@ class StringValue;
 class Operator;

+class ValueRefList;
+
+template <typename T>
+class Type {
+private:
+    const size_t m_code = T::TYPE_CODE;
+
+public:
+    inline size_t code() const { return m_code; }
+};
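Type<T> snapshots T::TYPE_CODE, a lazily registered inline static, into a constant member, so hot call sites can pass a token and compare plain integers instead of re-reading the static each time. A sketch with an assumed registry (toy type names, not MegEngine's):

    #include <cstddef>

    // Toy type registry: each TYPE_CODE is assigned once, on first use, much
    // like the register_type(typeid(T)) initializer in the diff.
    inline size_t next_code() {
        static size_t counter = 0;
        return counter++;
    }

    struct ShapeLike { static inline const size_t TYPE_CODE = next_code(); };
    struct DTypeLike { static inline const size_t TYPE_CODE = next_code(); };

    template <typename T>
    class TypeToken {
        const size_t m_code = T::TYPE_CODE;  // read the static once, cache it
    public:
        size_t code() const { return m_code; }
    };

    bool is_shape(size_t value_typecode, TypeToken<ShapeLike> tok = {}) {
        // a plain integer compare; no repeated lookups of the static
        return value_typecode == tok.code();
    }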
 /**
  * \brief a smart reference of value
  *
@@ -64,8 +75,9 @@ public:
 protected:
     mutable storage_t m_storage;
+    size_t m_id = std::numeric_limits<size_t>::max();

-    ValueRef(storage_t storage) { m_storage = storage; }
+    inline ValueRef(storage_t storage);

 private:
     /**
@@ -75,6 +87,10 @@ private:
      */
     storage_t& storage() const;

+    const Value* as(size_t typecode) const;
+
+    bool is(size_t typecode) const;
+
 public:
     ValueRef() = default;

@@ -86,7 +102,7 @@ public:
      * \return false if empty or type of value is not TValue
      */
     template <typename TValue>
-    bool is() const;
+    inline bool is(Type<TValue> type = {}) const;

     /**
      * \brief try cast value as target type
@@ -95,7 +111,7 @@ public:
      * \return TValue* raw pointer if success, otherwise nullptr
      */
     template <typename TValue>
-    const TValue* as() const;
+    inline const TValue* as(Type<TValue> type = {}) const;

     /**
      * \brief cast value to target type
@@ -104,7 +120,7 @@ public:
      * \return TValue& reference of value
      */
     template <typename TValue>
-    const TValue& cast() const;
+    inline const TValue& cast(Type<TValue> type = {}) const;

     /**
      * \brief like as(), but returns TypedValueRef instead
@@ -113,7 +129,13 @@ public:
      * \return TypedValueRef<TValue> reference if success, otherwise empty reference
      */
     template <typename TValue>
-    inline TypedValueRef<TValue> as_ref() const;
+    inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const;
+
+    template <typename TValue>
+    inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const;
+
+    template <typename TValue>
+    void on_cast_failure() const;

     operator bool() const { return bool(m_storage); }

@@ -132,7 +154,7 @@ public:
     ValueRef unwrap() const;
     std::string to_string() const;
     std::string raw_type() const;
-    uint64_t id() const;
+    uint64_t id() const { return m_id; }
     size_t hash() const { return id(); }

     static ValueRef make(storage_t storage);

@@ -144,7 +166,7 @@ public:
     friend class TypedValueRef;
     template <typename T>
     friend class ValueImpl;
-    friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
+    friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
 };

 template <>
@@ -244,7 +266,7 @@ public:
     using ref_t = TypedValueRef<T>;
     using weak_ref_t = TypedValueWeakRef<T>;

-    static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
+    static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }();

     /**
      * \brief helper function for construct a value
@@ -254,7 +276,7 @@ public:
      * \return TypedValueRef<T> reference of value
      */
     template <typename... TArgs>
-    static TypedValueRef<T> make(TArgs&&... args) {
+    static MGB_NOINLINE TypedValueRef<T> make(TArgs&&... args) {
         static_assert(std::is_final_v<T>);
         return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...));
     }
     bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; }
 };

+inline ValueRef::ValueRef(storage_t storage) {
+    // mgb_assert(storage);
+    m_storage = storage;
+    m_id = m_storage->m_id;
+}
+
 template <typename TValue>
-const TValue* ValueRef::as() const {
+inline const TValue* ValueRef::as(Type<TValue> type) const {
     static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>);
-    auto storage = this->storage();
-    if (!storage) {
-        return nullptr;
-    }
-    if (storage->m_typecode != TValue::TYPE_CODE) {
-        return nullptr;
-    }
-    return static_cast<TValue*>(storage.get());
+    return static_cast<const TValue*>(as(type.code()));
 }

 template <typename TValue>
-const TValue& ValueRef::cast() const {
-    auto* ptr = as<TValue>();
-    if (!ptr) {
-        // if this is ErrorValue, rethrow directly
-        storage()->try_rethrow();
-        mgb_assert(
-                ptr, "expect type %s, got %s", typeid(TValue).name(),
-                to_string().c_str());
+inline const TValue& ValueRef::cast(Type<TValue> type) const {
+    auto* ptr = as<TValue>(type);
+    if (mgb_unlikely(!ptr)) {
+        on_cast_failure<TValue>();
     }
-    return *ptr;
+    return static_cast<const TValue&>(*ptr);
+}
+
+template <typename TValue>
+inline bool ValueRef::is(Type<TValue> type) const {
+    return is(type.code());
 }

 template <typename TValue>
-bool ValueRef::is() const {
-    auto* ptr = as<TValue>();
-    return ptr != nullptr;
+inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const {
+    if (!is<TValue>(type)) {
+        return {};
+    }
+    return TypedValueRef<TValue>(*this);
 }

 template <typename TValue>
-TypedValueRef<TValue> ValueRef::as_ref() const {
-    if (!is<TValue>()) {
+inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const {
+    if (!m_storage) {
         return {};
     }
+    if (mgb_unlikely(!is<TValue>(type))) {
+        on_cast_failure<TValue>();
+    }
     return TypedValueRef<TValue>(*this);
 }

+template <typename TValue>
+void ValueRef::on_cast_failure() const {
+    // if this is ErrorValue, rethrow directly
+    storage()->try_rethrow();
+    // we only get here after a failed cast, so the typecodes cannot match and
+    // this assert always fires, reporting the expected and actual types
+    mgb_assert(
+            storage()->m_typecode == TValue::TYPE_CODE, "expect type %s, got %s",
+            typeid(TValue).name(), to_string().c_str());
+}
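The three accessors differ only in failure behavior: as() yields nullptr, as_ref() an empty ref, and cast()/cast_ref() funnel every failure through one cold path, on_cast_failure(). A toy mirror of that split (invented names; __builtin_expect is a GCC/Clang builtin standing in for mgb_unlikely):

    #include <cstddef>
    #include <stdexcept>
    #include <string>

    struct TaggedValue {
        size_t typecode;
    };

    // Cheap, nullable query, like ValueRef::as().
    template <size_t CODE>
    const TaggedValue* as(const TaggedValue& v) {
        return v.typecode == CODE ? &v : nullptr;
    }

    // Cold path shared by every failing cast, like ValueRef::on_cast_failure().
    template <size_t CODE>
    [[noreturn]] void on_cast_failure(const TaggedValue& v) {
        throw std::runtime_error(
                "expect typecode " + std::to_string(CODE) + ", got " +
                std::to_string(v.typecode));
    }

    // Throwing accessor, like ValueRef::cast().
    template <size_t CODE>
    const TaggedValue& cast(const TaggedValue& v) {
        auto* ptr = as<CODE>(v);
        if (__builtin_expect(!ptr, 0)) {
            on_cast_failure<CODE>(v);
        }
        return *ptr;
    }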
 /**
  * \brief ValueRef with concrete type, convenient for dereference
  *
@@ -361,11 +397,87 @@ private:
 public:
     TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {}
     TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {}
-    TypedValueRef<T> lock() { return ValueWeakRef::lock().template as_ref<T>(); }
+    TypedValueRef<T> lock() {
+        auto value = ValueWeakRef::lock();
+        if (value) {
+            return value.template as_ref<T>();
+        } else {
+            return {};
+        }
+    }
 };

 // TODO: add proxy value type, which is meant to be reset in the end

+class ValueRefList {
+private:
+    ValueRef* m_data = nullptr;
+    size_t m_size = 0;
+    std::aligned_storage_t<sizeof(ValueRef), alignof(ValueRef)> m_storage;
+
+private:
+    void init(size_t nr_elems);
+    ValueRef* inline_storage() { return reinterpret_cast<ValueRef*>(&m_storage); }
+
+public:
+    ValueRefList() = default;
+    ValueRefList(size_t nr_elems);
+    ValueRefList(ValueRef item);
+    ValueRefList(std::initializer_list<ValueRef> values);
+    template <typename TIterator>
+    ValueRefList(TIterator begin, TIterator end);
+    ValueRefList(const ValueRefList& rhs);
+    ValueRefList(ValueRefList&& rhs);
+    ValueRefList& operator=(const ValueRefList& rhs);
+    ValueRefList& operator=(ValueRefList&& rhs);
+    ~ValueRefList();
+    void clear();
+
+    ValueRef* begin() { return m_data; }
+    ValueRef* end() { return m_data + m_size; }
+    const ValueRef* cbegin() const { return m_data; }
+    const ValueRef* cend() const { return m_data + m_size; }
+    size_t size() const { return m_size; }
+    ValueRef& at(size_t idx) {
+        mgb_assert(idx < m_size);
+        return m_data[idx];
+    }
+    const ValueRef& at(size_t idx) const {
+        mgb_assert(idx < m_size);
+        return m_data[idx];
+    }
+    ValueRef& operator[](size_t idx) { return m_data[idx]; }
+    const ValueRef& operator[](size_t idx) const { return m_data[idx]; }
+    ValueRef* data() { return m_data; }
+    const ValueRef* data() const { return m_data; }
+    bool empty() const { return m_size == 0; }
+    ValueRef& front() {
+        mgb_assert(m_size > 0);
+        return m_data[0];
+    }
+    ValueRef& back() {
+        mgb_assert(m_size > 0);
+        return m_data[m_size - 1];
+    }
+};
+
+template <typename TIterator>
+ValueRefList::ValueRefList(TIterator begin, TIterator end) : ValueRefList(end - begin) {
+    for (size_t i = 0; i < m_size; ++i) {
+        m_data[i] = *(begin + i);
+    }
+}
+
+inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_size(1) {
+    new (m_data) ValueRef(std::move(item));
+}
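Single-result lists dominate dispatch, so ValueRefList keeps inline storage for exactly one element; the one-element constructor above never touches the heap. A self-contained toy of the same small-buffer split (std::string standing in for ValueRef; the n != 1 heap policy is an assumption read off the inline constructor):

    #include <cstddef>
    #include <memory>
    #include <new>
    #include <string>

    class OneOrMany {
        std::string* m_data = nullptr;
        size_t m_size = 0;
        alignas(std::string) unsigned char m_inline[sizeof(std::string)];

    public:
        explicit OneOrMany(std::string item) : m_size(1) {
            m_data = reinterpret_cast<std::string*>(m_inline);
            new (m_data) std::string(std::move(item));  // no heap allocation
        }
        explicit OneOrMany(size_t n) : m_size(n) {
            if (n == 1) {
                m_data = reinterpret_cast<std::string*>(m_inline);
                new (m_data) std::string();
            } else {
                m_data = new std::string[n];  // heap only when n != 1
            }
        }
        ~OneOrMany() {
            if (m_data == reinterpret_cast<std::string*>(m_inline)) {
                std::destroy_at(m_data);  // inline element: destroy in place
            } else {
                delete[] m_data;
            }
        }
        OneOrMany(const OneOrMany&) = delete;
        OneOrMany& operator=(const OneOrMany&) = delete;
        std::string& operator[](size_t i) { return m_data[i]; }
        size_t size() const { return m_size; }
    };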
+/*class ValueRefList : public SmallVector<ValueRef, 1> {
+public:
+    using SmallVector::SmallVector;
+};*/
+
 } // namespace imperative
 } // namespace mgb