@@ -16,6 +16,7 @@ import numpy as np | |||
from .. import _config | |||
from .._imperative_rt.common import CompNode | |||
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion | |||
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | |||
from ..ops import builtin | |||
from . import amp | |||
from .indexing import getitem, setitem | |||
@@ -508,12 +509,8 @@ def _reduce(mode): | |||
elif self.dtype == np.bool_: | |||
data = data.astype("int32") | |||
if axis is None: | |||
data = data.reshape(-1) | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
op = builtin.Reduce(mode=mode, axis=0) | |||
(result,) = apply(op, data) | |||
result = _remove_axis(result, 0) | |||
result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) | |||
elif isinstance(axis, collections.abc.Iterable): | |||
axis = _normalize_axis(self.ndim, axis, reverse=True) | |||
for ai in axis: | |||
@@ -69,7 +69,7 @@ class SGD(Optimizer): | |||
inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) | |||
if inplace_mode: | |||
_neg_lr = tensor(-lr, dtype="float32") | |||
c1 = tensor([1.0]) | |||
c1 = tensor(1.0) | |||
for param in param_group["params"]: | |||
if param.grad is None: | |||
@@ -84,14 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
device: str = None, | |||
is_const: bool = False, | |||
no_cache: bool = False, | |||
name: str = "", | |||
name: str = None, | |||
): | |||
if name is None: | |||
name = "" | |||
else: | |||
self._set_name(name) | |||
self._custom_name = name | |||
self._name = name | |||
self._short_name = name | |||
self._set_name(self._name) | |||
self._prefix = None | |||
@property | |||
@@ -46,17 +46,17 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||
if (args[1] != Py_None) { | |||
callback = py::reinterpret_borrow<py::object>(args[1]); | |||
} | |||
GenericFunction generic_callback = | |||
[=](Span<ValueRef> inputs) -> std::vector<ValueRef> { | |||
GenericFunction generic_callback = [=](Span<ValueRef> inputs) -> ValueRefList { | |||
mgb_assert(inputs.size() == 1); | |||
if (callback) { | |||
callback(TensorWrapper::make(py_tensor_type, inputs[0])); | |||
} | |||
return {}; | |||
}; | |||
tw->m_tensor->reset(imperative::apply( | |||
auto attached_value = imperative::apply( | |||
AttachGrad(m_key), tw->m_tensor->data(), | |||
FunctionValue::make(generic_callback))[0]); | |||
FunctionValue::make(generic_callback))[0]; | |||
tw->m_tensor->reset(attached_value); | |||
} | |||
void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) { | |||
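The callback plumbing changes type here as well: a GenericFunction now returns ValueRefList rather than std::vector<ValueRef>. A minimal sketch of the new callback shape, mirroring the attach() lambda above (Span, ValueRef and ValueRefList are the types used throughout this patch):

    GenericFunction noop_callback = [](Span<ValueRef> inputs) -> ValueRefList {
        mgb_assert(inputs.size() == 1);
        return {};  // empty ValueRefList, same contract as the attach() callback
    };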
@@ -98,7 +98,7 @@ ValueRef make_empty_tensor( | |||
return res; | |||
} | |||
std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
std::optional<ValueRefList> elemwise_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto& elemwise = op.cast_final_safe<Elemwise>(); | |||
@@ -117,7 +117,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(2); | |||
ValueRefList ret(2); | |||
if (!grad) { | |||
return ret; | |||
} | |||
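All of the grad rules below switch from push_back on a std::vector<ValueRef> to a ValueRefList sized up front. A small sketch of the container behaviour these rules rely on, inferred from their use in this patch rather than a separate API reference: default-constructed slots are empty ValueRefs that test false, so returning the list untouched means "no gradient".

    ValueRefList ret(2);              // two empty (falsy) slots
    mgb_assert(!ret[0] && !ret[1]);   // nothing assigned yet
    ret[0] = grad;                    // 'grad' is a hypothetical gradient ValueRef
    // slots left empty are treated downstream as "no grad for this input"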
@@ -132,7 +132,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
std::optional<ValueRefList> reshape_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
mgb_assert(inputs.size() == 2); | |||
@@ -147,7 +147,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(2); | |||
ValueRefList ret(2); | |||
if (!grad) { | |||
return ret; | |||
} | |||
@@ -162,7 +162,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
std::optional<ValueRefList> subtensor_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& subtensor = op.cast_final_safe<Subtensor>(); | |||
@@ -180,9 +180,9 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
ValueRefList ret(1); | |||
if (grad && inputs[0]) { | |||
SmallVector<ValueRef> args_(inputs.size() + 1); | |||
ValueRefList args_(inputs.size() + 1); | |||
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
args_[0] = zeros; | |||
args_[1] = grad; | |||
@@ -197,7 +197,7 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
std::optional<ValueRefList> indexingMultiAxisVec_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | |||
@@ -215,9 +215,9 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
ValueRefList ret(1); | |||
if (grad && inputs[0]) { | |||
SmallVector<ValueRef> args_(inputs.size() + 1); | |||
ValueRefList args_(inputs.size() + 1); | |||
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
args_[0] = zeros; | |||
args_[1] = grad; | |||
@@ -232,7 +232,7 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
std::optional<ValueRefList> reduce_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto& reduce = op.cast_final_safe<Reduce>(); | |||
@@ -251,7 +251,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
ValueRefList ret(1); | |||
if (grad && shapes[0]) { | |||
ret[0] = broadcast_to(grad, shapes[0]); | |||
} | |||
@@ -261,7 +261,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
return imperative::apply(ApplyOp(op), inputs); | |||
} | |||
std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
std::optional<ValueRefList> addAxis_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& addAxis = op.cast_final_safe<AddAxis>(); | |||
@@ -274,7 +274,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
ValueRefList ret(1); | |||
if (grad && flag_) { | |||
ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
} | |||
@@ -284,7 +284,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
return imperative::apply(op, inputs); | |||
} | |||
std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
std::optional<ValueRefList> removeAxis_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& removeAxis = op.cast_final_safe<RemoveAxis>(); | |||
@@ -297,7 +297,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
ValueRefList ret(1); | |||
if (grad && flag_) { | |||
ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
} | |||
@@ -307,7 +307,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
return imperative::apply(op, inputs); | |||
} | |||
std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule( | |||
std::optional<ValueRefList> fastpathcopy_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
mgb_assert(inputs.size() == 1); | |||
@@ -316,7 +316,7 @@ std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule( | |||
maker.backward([](Span<ValueRef> grads) { | |||
mgb_assert(grads.size() == 1); | |||
ValueRef grad = grads[0]; | |||
std::vector<ValueRef> ret(1); | |||
ValueRefList ret(1); | |||
if (grad) { | |||
ret[0] = grad; | |||
} | |||
@@ -25,24 +25,23 @@ private: | |||
py::function m_hook_fn; | |||
int m_enabled = 0; | |||
std::vector<ValueRef> apply_module_trace_hook( | |||
const OpDef& op, Span<ValueRef> input_values) { | |||
ValueRefList apply_module_trace_hook(const OpDef& op, Span<ValueRef> input_values) { | |||
py::list input_tws; | |||
for (auto&& input_value : input_values) { | |||
input_tws.append(TensorWrapper::make(py_tensor_type, input_value)); | |||
} | |||
py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws); | |||
std::vector<ValueRef> outputs; | |||
ValueRefList outputs(output_tws.size()); | |||
auto it = outputs.begin(); | |||
for (auto&& output_tw : output_tws) { | |||
outputs.push_back( | |||
TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data()); | |||
*(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data(); | |||
} | |||
return outputs; | |||
} | |||
public: | |||
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override { | |||
if (op.is<ApplyOp>() && m_enabled > 0) { | |||
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | |||
@@ -87,7 +87,7 @@ PyObject* py_apply( | |||
--nargs; | |||
auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); | |||
SmallVector<ValueRef, 64> tensors(nargs); | |||
SmallVector<ValueRef, 8> tensors(nargs); | |||
if (py::isinstance<PySymbolVar>(py::handle(args[0]))) { | |||
// swap to a special context to reuse scalar handle | |||
@@ -100,16 +100,15 @@ PyObject* py_apply( | |||
Transformation::top()); | |||
std::make_shared<ScalarTransformation>()->register_at( | |||
Transformation::top()); | |||
SmallVector<ValueRef> inputs(nargs); | |||
for (size_t i = 0; i < nargs; ++i) { | |||
auto* py_input = py::handle(args[i]).cast<PySymbolVar*>(); | |||
ValueRef input = SymbolValue::make(py_input->m_node); | |||
if (py_input->is_scalar) { | |||
input = ScalarValue::make(input); | |||
} | |||
inputs[i] = input; | |||
tensors[i] = input; | |||
} | |||
auto outputs = imperative::apply(*op, inputs); | |||
auto outputs = imperative::apply(*op, tensors); | |||
auto ret = pybind11::tuple(outputs.size()); | |||
auto typeobj = py::handle(args[0]).get_type(); | |||
for (size_t i = 0; i < outputs.size(); ++i) { | |||
@@ -140,7 +139,7 @@ PyObject* py_apply( | |||
} | |||
} | |||
auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs}); | |||
auto outputs = imperative::apply(*op, tensors); | |||
size_t nout = outputs.size(); | |||
auto ret = py::tuple(nout); | |||
for (size_t i = 0; i < nout; ++i) { | |||
@@ -214,16 +213,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
if (!name.empty()) { | |||
m_tensor->reset( | |||
imperative::apply(RenameValue(name), m_tensor->data())[0]); | |||
mgb_assert( | |||
((std::string&)*m_tensor->data().name()) == name, | |||
"result name incorrect"); | |||
} | |||
if (data.ndim() == 0) { | |||
mgb_assert(m_tensor->is_scalar(), "result should be scalar"); | |||
} | |||
} | |||
} | |||
mgb_assert(m_tensor->data()); | |||
} | |||
PyObject* TensorWrapper::module_trace_info() { | |||
@@ -1384,15 +1377,20 @@ void init_tensor(py::module m) { | |||
std::function<bool(py::object, py::object)> array_comparator; | |||
bool compare_value(ValueRef lhs, ValueRef rhs) { | |||
if (!lhs.shape()->eq(*rhs.shape())) { | |||
auto lvalue = lhs.numpy(); | |||
auto rvalue = rhs.numpy(); | |||
if (lvalue->shape() != rvalue->shape()) { | |||
return false; | |||
} | |||
HostTensorND lvalue = lhs.numpy()->as_nd(true); | |||
HostTensorND rvalue = rhs.numpy()->as_nd(true); | |||
if (lvalue->shape().is_scalar()) { | |||
return lvalue->item() == rvalue->item(); | |||
} | |||
HostTensorND lnd = lvalue->as_nd(true); | |||
HostTensorND rnd = rvalue->as_nd(true); | |||
auto larr = py::reinterpret_steal<py::array>( | |||
npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE)); | |||
npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE)); | |||
auto rarr = py::reinterpret_steal<py::array>( | |||
npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE)); | |||
npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE)); | |||
return array_comparator(larr, rarr); | |||
} | |||
@@ -1539,6 +1537,19 @@ void init_tensor(py::module m) { | |||
} | |||
}); | |||
m.def("reduce_to_scalar", [](py::object op, py::object tensor) { | |||
auto* tw = TensorWrapper::try_cast(tensor.ptr()); | |||
auto make_scalar_shape = [&](CompNode device) { | |||
return imperative::apply( | |||
CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}), | |||
HostStorage::make(device))[0]; | |||
}; | |||
auto output = imperative::apply( | |||
*op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(), | |||
make_scalar_shape(tw->m_tensor->comp_node()))[0]; | |||
return TensorWrapper::make(py_tensor_type, output); | |||
}); | |||
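This binding backs the Python-side _reduce change at the top of this diff: the reduce op is applied with one extra argument, an empty Int32 tensor (shape {0}) created on the input's device. That extra value is what is_scalar_shape() in the scalar transformation further down recognizes, so reduce_rule wraps the result back into a ScalarValue. A condensed restatement of the call pattern, with reduce_op standing in for the OpDef passed from Python:

    CompNode device = tw->m_tensor->comp_node();
    ValueRef scalar_shape = imperative::apply(
            CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
            HostStorage::make(device))[0];
    ValueRef result = imperative::apply(*reduce_op, tw->m_tensor->data(), scalar_shape)[0];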
m.def("name_tensor", [](std::string name, py::object tensor) { | |||
auto* tw = TensorWrapper::try_cast(tensor.ptr()); | |||
auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; | |||
@@ -1546,9 +1557,9 @@ void init_tensor(py::module m) { | |||
}); | |||
m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool { | |||
SmallVector<ValueRef> values; | |||
for (auto&& tensor : tensors) { | |||
values.push_back(tensor.cast<TensorWrapper>().m_tensor->data()); | |||
ValueRefList values(tensors.size()); | |||
for (size_t i = 0; i < tensors.size(); ++i) { | |||
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | |||
} | |||
auto outputs = imperative::apply(GetGradKey(), values); | |||
if (outputs[0].is<GradKeyValue>()) { | |||
@@ -1559,9 +1570,9 @@ void init_tensor(py::module m) { | |||
}); | |||
m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object { | |||
SmallVector<ValueRef> values; | |||
for (auto&& tensor : tensors) { | |||
values.push_back(tensor.cast<TensorWrapper>().m_tensor->data()); | |||
ValueRefList values(tensors.size()); | |||
for (size_t i = 0; i < tensors.size(); ++i) { | |||
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | |||
} | |||
auto outputs = imperative::apply(GetGradKey(), values); | |||
if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) { | |||
@@ -1578,7 +1589,7 @@ void init_tensor(py::module m) { | |||
mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr())); | |||
auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst(); | |||
GenericFunction generic_backward_fn = | |||
[backward_fn](Span<ValueRef> output_grads) -> std::vector<ValueRef> { | |||
[backward_fn](Span<ValueRef> output_grads) -> ValueRefList { | |||
py::list output_grad_tws; | |||
for (auto&& output_grad : output_grads) { | |||
if (output_grad) { | |||
@@ -1589,23 +1600,25 @@ void init_tensor(py::module m) { | |||
} | |||
} | |||
py::tuple input_grad_tws = backward_fn(*output_grad_tws); | |||
std::vector<ValueRef> input_grads; | |||
for (auto&& input_grad_tw : input_grad_tws) { | |||
ValueRefList input_grads(input_grad_tws.size()); | |||
for (size_t i = 0; i < input_grad_tws.size(); ++i) { | |||
auto input_grad_tw = input_grad_tws[i]; | |||
if (!input_grad_tw.is_none()) { | |||
input_grads.push_back( | |||
py::cast<TensorWrapper>(input_grad_tw).m_tensor->data()); | |||
input_grads[i] = | |||
py::cast<TensorWrapper>(input_grad_tw).m_tensor->data(); | |||
} else { | |||
input_grads.push_back({}); | |||
input_grads[i] = {}; | |||
} | |||
} | |||
return input_grads; | |||
}; | |||
SmallVector<ValueRef> values; | |||
for (auto&& input : inputs) { | |||
values.push_back(input.cast<TensorWrapper>().m_tensor->data()); | |||
ValueRefList values(inputs.size() + outputs.size()); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data(); | |||
} | |||
for (auto&& output : outputs) { | |||
values.push_back(output.cast<TensorWrapper>().m_tensor->data()); | |||
for (size_t i = 0; i < outputs.size(); ++i) { | |||
values[i + inputs.size()] = | |||
outputs[i].cast<TensorWrapper>().m_tensor->data(); | |||
} | |||
auto wrapped_output_values = imperative::apply( | |||
SetGrad(key->m_key, generic_backward_fn, inputs.size()), values); | |||
@@ -39,7 +39,7 @@ namespace mgb::imperative::python { | |||
extern interpreter::Interpreter::Channel* interpreter_for_py; | |||
extern PyTypeObject* py_tensor_type; | |||
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
struct Tensor : NonCopyableObj { | |||
private: | |||
std::string m_name; | |||
ValueRef m_data; | |||
@@ -52,7 +52,7 @@ public: | |||
~Tensor() = default; | |||
inline std::shared_ptr<Tensor> copy() { | |||
auto ret = std::make_shared<Tensor>(m_data.unwrap()); | |||
auto ret = std::make_shared<Tensor>(m_data); | |||
ret->m_name = m_name; | |||
return ret; | |||
} | |||
@@ -11,7 +11,15 @@ | |||
#pragma once | |||
#include <optional> | |||
#include <string> | |||
#include "pybind11/pybind11.h" | |||
#include "megbrain/imperative/dispatch.h" | |||
#include "megbrain/imperative/transformation.h" | |||
#include "megbrain/imperative/value.h" | |||
#include "megbrain/utils/small_vector.h" | |||
namespace mgb::imperative::python { | |||
struct TransformationManager { | |||
@@ -58,4 +66,14 @@ struct TransformationManager { | |||
return sl_instance; | |||
} | |||
}; | |||
class PyValue final : public MixinValueImpl<PyValue, pybind11::object> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const { | |||
return pybind11::str((const pybind11::object&)*this).cast<std::string>(); | |||
} | |||
}; | |||
} // namespace mgb::imperative::python |
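PyValue is a new value type that lets an arbitrary pybind11 object travel through the value system; to_string() simply stringifies the wrapped object. A hypothetical usage sketch, assuming the usual make() helper that other MixinValueImpl-derived types in this patch expose (ScalarValue::make, FunctionValue::make, ...):

    auto holder = PyValue::make(pybind11::int_(42));
    // holder->to_string() returns "42" via pybind11::str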
@@ -45,7 +45,7 @@ CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) | |||
layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); | |||
} | |||
auto CreateTensor::parse(Span<ValueRef> inputs) -> Args { | |||
auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args { | |||
Args result; | |||
for (auto&& input : inputs) { | |||
if (auto host_storage = input.as_ref<HostStorage>()) { | |||
@@ -16,70 +16,67 @@ | |||
#include "megbrain/imperative/utils/map.h" | |||
namespace mgb { | |||
void imperative_log_profile_begin(const char* message); | |||
void imperative_log_profile(const char* message); | |||
void imperative_log_profile_end(const char* message); | |||
namespace imperative { | |||
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs) { | |||
static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH"); | |||
bool enable_watch = ValueRef::any_watching(); | |||
auto& context = Transformation::get_context(); | |||
size_t& depth = context.next_transformation; | |||
static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; | |||
const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1; | |||
bool log_current_dispatch = log_dispatch; | |||
if (enable_watch) { | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
auto& input = inputs[i]; | |||
if (input.watching()) { | |||
log_current_dispatch = true; | |||
mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str()); | |||
debug::notify_event("apply"); | |||
} | |||
} | |||
} | |||
// entrance | |||
std::vector<ValueRef> outputs; | |||
if (depth >= context.transformations.size()) { | |||
// fallback | |||
if (log_current_dispatch) { | |||
mgb_log_debug( | |||
"%sfallback apply %s in %s", tabs, op.to_string().c_str(), | |||
imperative::to_string(inputs).c_str()); | |||
namespace { | |||
MGB_NOINLINE void copy_outputs( | |||
ForwardAllocator<ValueRef>& allocator, ValueRefList& outputs) { | |||
size_t nr_outputs = outputs.size(); | |||
if (mgb_likely(nr_outputs == 1)) { | |||
ValueRef output_copy; | |||
output_copy = outputs[0]; | |||
allocator.clear(); | |||
outputs = ValueRefList({output_copy}); | |||
} else if (!outputs.empty()) { | |||
SmallVector<ValueRef> outputs_copy(nr_outputs); | |||
for (size_t i = 0; i < nr_outputs; ++i) { | |||
outputs_copy[i] = outputs[i]; | |||
} | |||
outputs = op.fallback(inputs); | |||
outputs.clear(); | |||
allocator.clear(); | |||
outputs = {outputs_copy.begin(), outputs_copy.end()}; | |||
} else { | |||
// dispatch to stack top | |||
auto& transformation = *context.transformations[depth]; | |||
++depth; | |||
context.frames.push_back({op, inputs}); | |||
CleanupGuard _{[&] { | |||
context.frames.pop_back(); | |||
--depth; | |||
}}; | |||
if (log_current_dispatch) { | |||
mgb_log_debug( | |||
"%s%s apply %s in %s", tabs, transformation.name().c_str(), | |||
op.to_string().c_str(), imperative::to_string(inputs).c_str()); | |||
} | |||
outputs = transformation.apply_transformation(op, inputs); | |||
allocator.clear(); | |||
} | |||
if (log_current_dispatch) { | |||
mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str()); | |||
} | |||
} // namespace | |||
ValueRefList apply(const Operator& op, Span<ValueRef> inputs) { | |||
auto& context = Transformation::get_context(); | |||
size_t& depth = context.next_transformation; | |||
bool top = depth == 0; | |||
auto outputs = ([&] { | |||
if (mgb_unlikely(depth >= context.transformations.size())) { | |||
return op.fallback(inputs); | |||
} else { | |||
auto& transformation = *context.transformations[depth++]; | |||
CleanupGuard _{[&] { --depth; }}; | |||
return transformation.apply_transformation(op, inputs); | |||
} | |||
})(); | |||
if (mgb_unlikely(top)) { | |||
copy_outputs(context.allocator, outputs); | |||
} | |||
return outputs; | |||
} | |||
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs) { | |||
ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) { | |||
return imperative::apply(ApplyOp{def}, inputs); | |||
} | |||
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) { | |||
ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) { | |||
SmallVector<ValueRef> inputs_storage; | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
inputs_storage.push_back(inputs[i]); | |||
} | |||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs, | |||
size_t) { | |||
auto outputs = imperative::apply(ApplyOp(*op), inputs); | |||
auto outputs = imperative::apply(*op, inputs); | |||
return SmallVector<ValueRef>(outputs.begin(), outputs.end()); | |||
}; | |||
auto make_const = [](TensorPtr constant) -> ValueRef { | |||
@@ -101,7 +98,7 @@ std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) { | |||
DeviceStorage::make(device_value.storage()))[0]; | |||
}; | |||
auto outputs = graph.apply(inputs_storage, apply_functor, make_const); | |||
return {outputs.begin(), outputs.end()}; | |||
return ValueRefList{outputs.begin(), outputs.end()}; | |||
} | |||
} // namespace imperative | |||
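The rewritten apply() keeps the outputs of every nested dispatch in the per-context ForwardAllocator arena; only when the outermost call (depth == 0) returns does copy_outputs() copy the final results out and clear the arena, so intermediate ValueRefLists need no individual cleanup. From a caller's perspective only the return type changes; a minimal usage sketch, with op_def a placeholder std::shared_ptr<OpDef> and x a placeholder ValueRef:

    ValueRefList outs = imperative::apply(ApplyOp(*op_def), x);
    ValueRef y = outs[0];   // already copied out of the arena by copy_outputs(), safe to keep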
@@ -126,7 +126,7 @@ public: | |||
m_frames[m_frames.size() - 1 - i] = {node, node->version()}; | |||
node = node->parent(); | |||
} | |||
mgb_assert(node->is_root(), ""); | |||
mgb_assert(node->is_root()); | |||
} | |||
Trace() = default; | |||
std::string to_string() const { | |||
@@ -3,7 +3,7 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
std::vector<ValueRef> Operator::fallback(Span<ValueRef> inputs) const { | |||
ValueRefList Operator::fallback(Span<ValueRef> inputs) const { | |||
mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str()); | |||
} | |||
@@ -99,19 +99,22 @@ Tensor::Tensor( | |||
Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { | |||
constexpr int size_threshold = TensorShape::MAX_NDIM; | |||
if (hv.layout().total_nr_elems() <= size_threshold) { | |||
size_t nr_elems = hv.layout().total_nr_elems(); | |||
if (nr_elems <= size_threshold) { | |||
m_value = hv; | |||
} | |||
MGB_RECORD_EVENT( | |||
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), | |||
dev_tensor().raw_ptr()); | |||
dev_tensor().copy_from_fixlayout(hv); | |||
// even though hv is saved in m_value, Tensor itself could be | |||
// released before copy completes | |||
MGB_RECORD_EVENT( | |||
profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), | |||
hv.raw_ptr(), dev_tensor().raw_ptr()); | |||
AsyncReleaser::inst()->add(hv); | |||
if (nr_elems) { | |||
MGB_RECORD_EVENT( | |||
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), | |||
dev_tensor().raw_ptr()); | |||
dev_tensor().copy_from_fixlayout(hv); | |||
// even though hv is saved in m_value, Tensor itself could be | |||
// released before copy completes | |||
MGB_RECORD_EVENT( | |||
profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), | |||
hv.raw_ptr(), dev_tensor().raw_ptr()); | |||
AsyncReleaser::inst()->add(hv); | |||
} | |||
} | |||
Tensor::Tensor(const DeviceTensorND& dv, const HostTensorND& hv) { | |||
@@ -310,7 +310,8 @@ struct ChromeTimelineEventVisitor : EventVisitor<ChromeTimelineEventVisitor> { | |||
} else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) { | |||
new_host_event("TensorGetProp", 'X') | |||
.dur(0) | |||
.args(current_tensor->detail(current->time)); | |||
.args(current_tensor->detail(current->time)) | |||
.arg("kind", imperative::to_string(event.prop)); | |||
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) { | |||
new_host_event("TensorWaitProp", 'B'); | |||
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) { | |||
@@ -15,71 +15,109 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
std::vector<ValueRef> InterpreterTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* op_val = op.as<ApplyOp>()) { | |||
if (op_val->op().same_type<FastpathCopy>()) { | |||
return {inputs[0]}; | |||
} | |||
SmallVector<Handle> input_handles; | |||
SmallVector<Handle> output_handles; | |||
CleanupGuard _{[&] { | |||
for (auto handle : output_handles) { | |||
if (handle) { | |||
m_channel->del(handle); | |||
} | |||
DTypeValue::ref_t InterpreterInfo::dtype() const { | |||
if (!m_dtype) { | |||
m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle())); | |||
} | |||
return m_dtype; | |||
} | |||
CompNodeValue::ref_t InterpreterInfo::comp_node() const { | |||
if (!m_comp_node) { | |||
m_comp_node = CompNodeValue::make( | |||
handle()->channel()->get_device(handle()->handle())); | |||
} | |||
return m_comp_node; | |||
} | |||
ShapeValue::ref_t InterpreterInfo::shape() const { | |||
if (!m_shape) { | |||
m_shape = ShapeValue::make( | |||
ValueShape::from(handle()->channel()->get_shape(handle()->handle()))); | |||
} | |||
return m_shape; | |||
} | |||
ValueRefList InterpreterTransformation::apply_op( | |||
const ApplyOp& apply_op, Span<ValueRef> inputs) { | |||
if (apply_op.op().same_type<FastpathCopy>()) { | |||
return {inputs[0]}; | |||
} | |||
SmallVector<Handle> input_handles; | |||
SmallVector<Handle> output_handles; | |||
CleanupGuard _{[&] { | |||
for (auto handle : output_handles) { | |||
if (handle) { | |||
m_channel->del(handle); | |||
} | |||
}}; | |||
for (auto input : inputs) { | |||
input_handles.push_back(*input.cast<InterpreterValue>().handle()); | |||
} | |||
output_handles = | |||
m_channel->apply_op(op_val->op().shared_from_this(), input_handles); | |||
std::vector<ValueRef> outputs; | |||
for (auto& handle : output_handles) { | |||
outputs.push_back(InterpreterValue::make(share_handle(handle))); | |||
handle = nullptr; | |||
} | |||
return outputs; | |||
}}; | |||
for (auto input : inputs) { | |||
input_handles.push_back(input.cast<InterpreterValue>().handle()->handle()); | |||
} | |||
output_handles = | |||
m_channel->apply_op(apply_op.op().shared_from_this(), input_handles); | |||
ValueRefList outputs(output_handles.size()); | |||
for (size_t i = 0; i < output_handles.size(); ++i) { | |||
outputs[i] = InterpreterValue::make(share_handle(output_handles[i])); | |||
output_handles[i] = nullptr; | |||
} | |||
return outputs; | |||
} | |||
ValueRefList InterpreterTransformation::apply_get_attr( | |||
const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
auto& input = inputs.item().cast<InterpreterValue>(); | |||
ValueRef output; | |||
switch (get_attr.attr()) { | |||
case GetAttr::DType: | |||
output = input.dtype(); | |||
break; | |||
case GetAttr::Shape: | |||
output = input.shape(); | |||
break; | |||
case GetAttr::Device: | |||
output = input.comp_node(); | |||
break; | |||
case GetAttr::Value: | |||
output = HostValue::make(m_channel->get_value(input.handle()->handle())); | |||
break; | |||
case GetAttr::Data: | |||
output = DeviceValue::make( | |||
m_channel->get_dev_tensor(input.handle()->handle())); | |||
break; | |||
default: | |||
mgb_throw( | |||
MegBrainError, "Interpreter: malformed GetAttr: %s", | |||
get_attr.to_string().c_str()); | |||
} | |||
return {output}; | |||
} | |||
ValueRefList InterpreterTransformation::apply_create_tensor( | |||
const CreateTensor& create_tensor, Span<ValueRef> inputs) { | |||
auto args = create_tensor.parse(inputs); | |||
if (!args.device) { | |||
// implies H2D | |||
mgb_assert(args.host, "neither host nor device value is valid"); | |||
return {InterpreterValue::make(share_handle( | |||
m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; | |||
} else { | |||
return {InterpreterValue::make(share_handle(m_channel->put( | |||
*args.device, args.host ? *args.host : HostTensorND())))}; | |||
} | |||
} | |||
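apply_transformation is split here into apply_op / apply_get_attr / apply_create_tensor, and InterpreterInfo now caches dtype, comp_node and shape on first use instead of asking the channel every time. A hedged illustration of the effect, assuming value is backed by an InterpreterValue:

    auto shape_first  = value.shape();   // GetAttr::Shape reaches m_channel->get_shape(), result cached
    auto shape_second = value.shape();   // served from the cached ShapeValue, no second channel query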
ValueRefList InterpreterTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* op_val = op.as<ApplyOp>()) { | |||
return apply_op(*op_val, inputs); | |||
} else if (auto* get_attr = op.as<GetAttr>()) { | |||
Handle handle = *inputs[0].cast<InterpreterValue>().handle(); | |||
ValueRef output; | |||
switch (get_attr->attr()) { | |||
case GetAttr::DType: | |||
output = DTypeValue::make(m_channel->get_dtype(handle)); | |||
break; | |||
case GetAttr::Shape: | |||
output = ShapeValue::make( | |||
ValueShape::from(m_channel->get_shape(handle))); | |||
break; | |||
case GetAttr::Device: | |||
output = CompNodeValue::make(m_channel->get_device(handle)); | |||
break; | |||
case GetAttr::Value: | |||
output = HostValue::make(m_channel->get_value(handle)); | |||
break; | |||
case GetAttr::Data: | |||
output = DeviceValue::make(m_channel->get_dev_tensor(handle)); | |||
break; | |||
default: | |||
mgb_throw( | |||
MegBrainError, "Interpreter: malformed GetAttr: %s", | |||
op.to_string().c_str()); | |||
} | |||
return {output}; | |||
return apply_get_attr(*get_attr, inputs); | |||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
auto args = create_tensor->parse(inputs); | |||
if (!args.device) { | |||
// implies H2D | |||
mgb_assert(args.host, "neither host and device value is valid"); | |||
return {InterpreterValue::make(share_handle( | |||
m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; | |||
} else { | |||
return {InterpreterValue::make(share_handle(m_channel->put( | |||
*args.device, args.host ? *args.host : HostTensorND())))}; | |||
} | |||
return apply_create_tensor(*create_tensor, inputs); | |||
} else if (auto* dtr_command = op.as<DTRCommand>()) { | |||
auto handle = *inputs[0].cast<InterpreterValue>().handle(); | |||
auto handle = inputs[0].cast<InterpreterValue>().handle()->handle(); | |||
switch (dtr_command->kind()) { | |||
case DTRCommand::Drop: | |||
m_channel->drop(handle); | |||
@@ -64,12 +64,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
size_t count = std::count_if( | |||
save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); | |||
if (!backward_graph->precomp.empty()) { | |||
SmallVector<ValueRef> inputs_and_outputs; | |||
ValueRefList inputs_and_outputs(inputs.size() + outputs.size()); | |||
auto it = inputs_and_outputs.begin(); | |||
for (auto&& input : inputs) { | |||
inputs_and_outputs.push_back(input); | |||
*it++ = input; | |||
} | |||
for (auto&& output : outputs) { | |||
inputs_and_outputs.push_back(output); | |||
*it++ = output; | |||
} | |||
auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs); | |||
closure.reserve(precomp.size() + count); | |||
@@ -89,7 +90,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
} | |||
} | |||
void BackwardGraphWithClosure::operator()( | |||
std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | |||
ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) { | |||
ValueRef args[closure.size() + grads.size()]; | |||
size_t nargs = 0; | |||
for (auto&& value : closure) { | |||
@@ -120,7 +121,7 @@ void BackwardGraphWithClosure::operator()( | |||
} | |||
void CustomBackward::operator()( | |||
std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | |||
ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) { | |||
size_t nargs = grads.size(); | |||
ValueRef args[nargs]; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
@@ -201,9 +202,10 @@ void GradKey::backward() { | |||
mgb_throw(AssertionError, "invalid backward"); | |||
} else { | |||
mgb_assert(grad_fn->m_slots.size() > 0); | |||
std::vector<ValueRef> grads; | |||
ValueRefList grads(grad_fn->m_slots.size()); | |||
auto iter = grads.begin(); | |||
for (auto&& slot : grad_fn->m_slots) { | |||
grads.push_back(slot.m_grad); | |||
*iter++ = slot.m_grad; | |||
} | |||
backward(grads, grad_receiver); | |||
} | |||
@@ -254,21 +256,28 @@ void GradKey::freeze() { | |||
m_frozen = true; | |||
} | |||
std::vector<ValueRef> GradTransformation::apply_transformation( | |||
ValueRefList GradTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> { | |||
SmallVector<ValueRef> unwrapped_inputs; | |||
for (auto&& input : inputs) { | |||
if (auto grad_value = as_grad_value(input)) { | |||
unwrapped_inputs.push_back(grad_value->m_value); | |||
auto fallback = [&] { | |||
ValueRefList unwrapped_inputs(inputs.size()); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
if (auto grad_value = as_grad_value(inputs[i])) { | |||
unwrapped_inputs[i] = grad_value->m_value; | |||
} else { | |||
unwrapped_inputs.push_back(input); | |||
unwrapped_inputs[i] = inputs[i]; | |||
} | |||
} | |||
return unwrapped_inputs; | |||
return imperative::apply(op, unwrapped_inputs); | |||
}; | |||
if (auto* get_attr = op.as<GetAttr>()) { | |||
if (auto grad_value = as_grad_value(inputs.item())) { | |||
return imperative::apply(op, grad_value->m_value); | |||
} else { | |||
return imperative::apply(op, inputs); | |||
} | |||
} | |||
if (m_suppressed) { | |||
return imperative::apply(op, unwrap_inputs(inputs)); | |||
return fallback(); | |||
} | |||
if (auto* op_val = op.as<ApplyOp>()) { | |||
size_t nr_require_grad = 0; | |||
@@ -284,20 +293,21 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
if (nr_require_grad == 0) { | |||
return imperative::apply(op, inputs); | |||
} | |||
SmallVector<ValueRef> captured_inputs; | |||
SmallVector<bool> inputs_require_grad; | |||
ValueRefList captured_inputs(inputs.size()); | |||
SmallVector<bool> inputs_require_grad(inputs.size()); | |||
// capture value so that trace can assume the inputs are the same | |||
auto capture_value = [](ValueRef value) { | |||
// TODO: fastpath copy shouldn't be an OpDef | |||
return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0]; | |||
}; | |||
for (auto& input : inputs) { | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
auto& input = inputs[i]; | |||
if (auto grad_value = as_grad_value(input)) { | |||
captured_inputs.push_back(capture_value(grad_value->m_value)); | |||
inputs_require_grad.push_back(true); | |||
captured_inputs[i] = capture_value(grad_value->m_value); | |||
inputs_require_grad[i] = true; | |||
} else { | |||
captured_inputs.push_back(capture_value(input)); | |||
inputs_require_grad.push_back(false); | |||
captured_inputs[i] = capture_value(input); | |||
inputs_require_grad[i] = false; | |||
} | |||
} | |||
decltype(std::declval<GradFn>().m_backward) backward_storage; | |||
@@ -373,9 +383,11 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
mgb_assert(!grad_fn->m_slots.empty()); | |||
m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()}); | |||
return outputs; | |||
} else if (op.is<CreateTensor>()) { | |||
return imperative::apply(op, inputs); | |||
} else if (auto* attach_grad = op.as<AttachGrad>()) { | |||
if (!has_key(attach_grad->key())) { | |||
return imperative::apply(op, unwrap_inputs(inputs)); | |||
return fallback(); | |||
} | |||
auto tensor = inputs[0]; | |||
GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>(); | |||
@@ -386,7 +398,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
return {record_grad(output)}; | |||
} else if (auto* grad_backward = op.as<GradBackward>()) { | |||
if (!has_key(grad_backward->key())) { | |||
return imperative::apply(op, unwrap_inputs(inputs)); | |||
return fallback(); | |||
} | |||
size_t nr_grads = inputs.size() / 2; | |||
mgb_assert(nr_grads * 2 == inputs.size()); | |||
@@ -416,7 +428,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
backward.m_output_attrs = | |||
SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); | |||
backward.m_backward = set_grad->grad_fn(); | |||
std::vector<ValueRef> outputs; | |||
ValueRefList outputs(nr_outputs); | |||
grad_fn->m_key = m_key; | |||
grad_fn->m_slots.resize(nr_outputs); | |||
grad_fn->m_dests.reserve(nr_inputs); | |||
@@ -439,13 +451,13 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
} else { | |||
grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i)); | |||
} | |||
outputs.push_back(record_grad(grad_value)); | |||
outputs[i] = record_grad(grad_value); | |||
} | |||
m_key->m_tape.push_back({grad_fn, nullptr}); | |||
return outputs; | |||
} else if (auto* gbc = op.as<GetBackwardColsure>()) { | |||
if (gbc->key() != m_key) { | |||
return imperative::apply(op, unwrap_inputs(inputs)); | |||
return fallback(); | |||
} | |||
return {FunctionValue::make(make_backward_closure(inputs))}; | |||
} else if (op.is<DetachGrad>()) { | |||
@@ -471,21 +483,8 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
} else { | |||
return imperative::apply(op, inputs); | |||
} | |||
} else if (op.is<CreateTensor>()) { | |||
return imperative::apply(op, inputs); | |||
} else { | |||
SmallVector<ValueRef> unwrapped_inputs; | |||
for (auto&& input : inputs) { | |||
if (auto grad_value = as_grad_value(input)) { | |||
unwrapped_inputs.push_back(grad_value->m_value); | |||
} else { | |||
unwrapped_inputs.push_back(input); | |||
} | |||
} | |||
auto outputs = imperative::apply( | |||
op, {unwrapped_inputs.data(), unwrapped_inputs.size()}); | |||
mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty()); | |||
return outputs; | |||
return fallback(); | |||
} | |||
} | |||
@@ -500,8 +499,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) { | |||
y_slots.emplace_back(); | |||
} | |||
} | |||
GenericFunction closure = [grad_key, | |||
y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> { | |||
GenericFunction closure = [grad_key, y_slots](Span<ValueRef> dys) -> ValueRefList { | |||
size_t nr_grads = y_slots.size(); | |||
mgb_assert(dys.size() == nr_grads); | |||
for (size_t i = 0; i < nr_grads; ++i) { | |||
@@ -21,7 +21,7 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
std::vector<ValueRef> LazyEvalTransformation::apply_transformation( | |||
ValueRefList LazyEvalTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* op_val = op.as<ApplyOp>()) { | |||
static std::unordered_set<Typeinfo*> mm_io_ops = { | |||
@@ -59,9 +59,9 @@ std::vector<ValueRef> LazyEvalTransformation::apply_transformation( | |||
mgb_assert(!output_nodes.empty()); | |||
m_io_link = SymbolVar(output_nodes[0]); | |||
} | |||
std::vector<ValueRef> outputs; | |||
for (auto&& output_node : output_nodes) { | |||
outputs.push_back(record_var(output_node)); | |||
ValueRefList outputs(output_nodes.size()); | |||
for (size_t i = 0; i < output_nodes.size(); ++i) { | |||
outputs[i] = record_var(output_nodes[i]); | |||
} | |||
return outputs; | |||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
@@ -19,26 +19,8 @@ namespace imperative { | |||
namespace { | |||
using ScalarRule = std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>; | |||
static std::unordered_map< | |||
Typeinfo*, std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>> | |||
scalar_rules; | |||
ValueRef unwrap_input(ValueRef input) { | |||
if (auto scalar_input = input.as_ref<ScalarValue>()) { | |||
return scalar_input->value(); | |||
} else { | |||
return input; | |||
} | |||
} | |||
std::vector<ValueRef> unwrap_inputs(Span<ValueRef> inputs) { | |||
std::vector<ValueRef> unwrapped_inputs; | |||
for (auto&& input : inputs) { | |||
unwrapped_inputs.push_back(unwrap_input(input)); | |||
} | |||
return unwrapped_inputs; | |||
} | |||
using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>); | |||
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules; | |||
ValueRef make_scalar_shape(CompNode device) { | |||
HostTensorND scalar_shape(device, {1}, dtype::Int32()); | |||
@@ -49,9 +31,6 @@ ValueRef make_scalar_shape(CompNode device) { | |||
} | |||
bool is_scalar_shape(ValueRef shape) { | |||
if (shape.is<ScalarValue>()) { | |||
return false; | |||
} | |||
// may have performance issue | |||
auto shape_of_shape = shape.shape(); | |||
if (!shape_of_shape) { | |||
@@ -61,74 +40,65 @@ bool is_scalar_shape(ValueRef shape) { | |||
return *shape_of_shape == ValueShape{0}; | |||
} | |||
template <typename T> | |||
void register_scalar_rule(std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>)) { | |||
scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef> inputs) { | |||
return (*rule)(def.cast_final_safe<T>(), inputs); | |||
template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)> | |||
void register_scalar_rule() { | |||
scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs, | |||
Span<bool> inputs_mask) { | |||
return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask); | |||
}; | |||
} | |||
std::vector<ValueRef> elemwise_rule(const Elemwise& elem, Span<ValueRef> inputs) { | |||
template <typename TOpDef, size_t nr_inputs> | |||
ValueRefList elemwise_rule( | |||
const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
if constexpr (nr_inputs != 0) { | |||
mgb_assert(inputs.size() == nr_inputs, "inputs size mismatch"); | |||
} | |||
bool all_scalar = true; | |||
for (auto&& input : inputs) { | |||
if (!input.is<ScalarValue>()) { | |||
for (auto&& input_mask : inputs_mask) { | |||
if (!input_mask) { | |||
all_scalar = false; | |||
break; | |||
} | |||
} | |||
auto output = imperative::apply(elem, unwrap_inputs(inputs))[0]; | |||
auto outputs = imperative::apply(op_def, inputs); | |||
if (all_scalar) { | |||
return {ScalarValue::make(output)}; | |||
} else { | |||
return {output}; | |||
outputs[0] = ScalarValue::make(outputs[0]); | |||
} | |||
return outputs; | |||
} | |||
std::vector<ValueRef> remove_axis_rule( | |||
const RemoveAxis& remove_axis, Span<ValueRef> inputs) { | |||
mgb_assert(inputs.size() == 1); | |||
mgb_assert(!inputs[0].is<ScalarValue>()); | |||
auto output = imperative::apply(remove_axis, inputs)[0]; | |||
bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size(); | |||
ValueRefList remove_axis_rule( | |||
const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
mgb_assert(!inputs_mask.item()); | |||
bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size(); | |||
if (is_scalar && remove_axis.axis.size() == 1) { | |||
return {ScalarValue::make(inputs.item())}; | |||
} | |||
auto outputs = imperative::apply(remove_axis, inputs); | |||
if (is_scalar) { | |||
return {ScalarValue::make(output)}; | |||
} else { | |||
return {output}; | |||
outputs[0] = ScalarValue::make(outputs[0]); | |||
} | |||
return outputs; | |||
} | |||
std::vector<ValueRef> reduce_rule(const Reduce& reduce, Span<ValueRef> inputs) { | |||
ValueRefList reduce_rule( | |||
const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
if (inputs.size() == 1) { | |||
return imperative::apply(reduce, unwrap_inputs(inputs)); | |||
return imperative::apply(reduce, inputs); | |||
} | |||
mgb_assert(inputs.size() == 2); | |||
bool is_scalar = is_scalar_shape(inputs[1]); | |||
if (is_scalar) { | |||
auto unwrapped_input = unwrap_input(inputs[0]); | |||
CompNode device = *unwrapped_input.device(); | |||
return {ScalarValue::make(imperative::apply( | |||
reduce, unwrapped_input, make_scalar_shape(device))[0])}; | |||
} | |||
auto output = imperative::apply(reduce, unwrap_inputs(inputs))[0]; | |||
if (is_scalar) { | |||
return {ScalarValue::make(output)}; | |||
} else { | |||
return {output}; | |||
} | |||
} | |||
std::vector<ValueRef> typecvt_rule(const TypeCvt& typecvt, Span<ValueRef> inputs) { | |||
mgb_assert(inputs.size() == 1); | |||
if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) { | |||
CompNode device = *inputs[0].device(); | |||
return {ScalarValue::make( | |||
imperative::apply(typecvt, scalar_input->value())[0])}; | |||
} else { | |||
return imperative::apply(typecvt, inputs); | |||
imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])}; | |||
} | |||
return imperative::apply(reduce, inputs); | |||
} | |||
std::vector<ValueRef> collective_comm_rule( | |||
const CollectiveComm& collective_comm, Span<ValueRef> inputs) { | |||
ValueRefList collective_comm_rule( | |||
const CollectiveComm& collective_comm, Span<ValueRef> inputs, | |||
Span<bool> inputs_mask) { | |||
mgb_assert(inputs.size() == 1); | |||
static std::unordered_set<CollectiveComm::Mode> modes = { | |||
CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN, | |||
@@ -138,17 +108,17 @@ std::vector<ValueRef> collective_comm_rule( | |||
if (modes.count(collective_comm.mode) == 0) { | |||
return imperative::apply(collective_comm, inputs); | |||
} | |||
if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) { | |||
return {ScalarValue::make( | |||
imperative::apply(collective_comm, scalar_input->value())[0])}; | |||
if (inputs_mask.item()) { | |||
return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])}; | |||
} else { | |||
return imperative::apply(collective_comm, inputs); | |||
} | |||
} | |||
std::vector<ValueRef> param_pack_split_rule( | |||
const ParamPackSplit& param_pack_split, Span<ValueRef> inputs) { | |||
auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs)); | |||
ValueRefList param_pack_split_rule( | |||
const ParamPackSplit& param_pack_split, Span<ValueRef> inputs, | |||
Span<bool> inputs_mask) { | |||
auto outputs = imperative::apply(param_pack_split, inputs); | |||
size_t nr_outputs = outputs.size(); | |||
mgb_assert(nr_outputs == param_pack_split.shapes.size()); | |||
for (size_t i = 0; i < nr_outputs; ++i) { | |||
@@ -159,29 +129,28 @@ std::vector<ValueRef> param_pack_split_rule( | |||
return outputs; | |||
} | |||
std::vector<ValueRef> dot_rule(const Dot& dot, Span<ValueRef> inputs) { | |||
return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])}; | |||
ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
return {ScalarValue::make(imperative::apply(dot, inputs)[0])}; | |||
} | |||
std::vector<ValueRef> add_axis_rule(const AddAxis& add_axis, Span<ValueRef> inputs) { | |||
ValueRefList add_axis_rule( | |||
const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
mgb_assert(inputs.size() == 1); | |||
if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) { | |||
if (inputs_mask.item()) { | |||
mgb_assert(add_axis.axis[0] == 0); | |||
if (add_axis.axis.size() == 1) { | |||
return {scalar_input->value()}; | |||
return {inputs[0]}; | |||
} else { | |||
std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end()); | |||
return imperative::apply( | |||
ApplyOp(*AddAxis::make(axis, add_axis.scope())), | |||
scalar_input->value()); | |||
return imperative::apply(*AddAxis::make(axis, add_axis.scope()), inputs[0]); | |||
} | |||
} else { | |||
return imperative::apply(add_axis, inputs); | |||
} | |||
} | |||
std::vector<ValueRef> remote_recv_rule( | |||
const RemoteRecv& remote_recv, Span<ValueRef> inputs) { | |||
ValueRefList remote_recv_rule( | |||
const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
if (remote_recv.shape.empty()) { | |||
std::vector<int32_t> shape = {1}; | |||
auto remote_recv_no_scalar = RemoteRecv::make( | |||
@@ -189,32 +158,32 @@ std::vector<ValueRef> remote_recv_rule( | |||
remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype, | |||
remote_recv.backend); | |||
remote_recv_no_scalar->set_scope(remote_recv.scope()); | |||
return imperative::apply( | |||
ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs)); | |||
return imperative::apply(ApplyOp(*remote_recv_no_scalar), inputs); | |||
} else { | |||
return imperative::apply(remote_recv, unwrap_inputs(inputs)); | |||
return imperative::apply(remote_recv, inputs); | |||
} | |||
} | |||
std::vector<ValueRef> check_no_finite_rule( | |||
const CheckNonFinite& check_no_finite, Span<ValueRef> inputs) { | |||
auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs)); | |||
ValueRefList check_no_finite_rule( | |||
const CheckNonFinite& check_no_finite, Span<ValueRef> inputs, | |||
Span<bool> inputs_mask) { | |||
auto outputs = imperative::apply(check_no_finite, inputs); | |||
mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch"); | |||
outputs.back() = ScalarValue::make(outputs.back()); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
if (inputs[i].is<ScalarValue>()) { | |||
if (inputs_mask[i]) { | |||
outputs[i] = ScalarValue::make(outputs[i]); | |||
} | |||
} | |||
return outputs; | |||
} | |||
std::vector<ValueRef> subtensor_rule( | |||
const Subtensor& subtensor, Span<ValueRef> inputs) { | |||
ValueRefList subtensor_rule( | |||
const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
mgb_assert(inputs.size() >= 1); | |||
auto input = inputs[0]; | |||
bool is_scalar; | |||
mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input"); | |||
mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input"); | |||
if (auto shape = input.shape()) { | |||
size_t ndim = input.shape()->ndim; | |||
for (auto&& [axis, begin, end, step, idx] : subtensor.items) { | |||
@@ -226,25 +195,25 @@ std::vector<ValueRef> subtensor_rule( | |||
} else { | |||
is_scalar = false; | |||
} | |||
auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0]; | |||
auto outputs = imperative::apply(subtensor, inputs); | |||
if (is_scalar) { | |||
return {ScalarValue::make(output)}; | |||
} else { | |||
return {output}; | |||
outputs[0] = ScalarValue::make(outputs[0]); | |||
} | |||
return outputs; | |||
} | |||
std::vector<ValueRef> get_var_shape_rule( | |||
const GetVarShape& get_var_shape, Span<ValueRef> inputs) { | |||
ValueRefList get_var_shape_rule( | |||
const GetVarShape& get_var_shape, Span<ValueRef> inputs, | |||
Span<bool> inputs_mask) { | |||
bool all_scalar = true; | |||
mgb_assert(inputs.size() >= 1); | |||
for (auto&& input : inputs) { | |||
if (!input.is<ScalarValue>()) { | |||
for (auto&& input_mask : inputs_mask) { | |||
if (!input_mask) { | |||
all_scalar = false; | |||
} | |||
} | |||
if (all_scalar) { | |||
auto device = inputs[0].cast<ScalarValue>().value().device(); | |||
auto device = inputs[0].device(); | |||
auto storage = HostStorage::make(*device); | |||
// storage->ensure_size(1); | |||
return imperative::apply( | |||
@@ -252,88 +221,49 @@ std::vector<ValueRef> get_var_shape_rule( | |||
CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}), | |||
storage); | |||
} else { | |||
return imperative::apply(get_var_shape, unwrap_inputs(inputs)); | |||
} | |||
} | |||
std::vector<ValueRef> fastpath_copy_rule( | |||
const FastpathCopy& fastpath_copy, Span<ValueRef> inputs) { | |||
mgb_assert(inputs.size() == 1); | |||
bool is_scalar = inputs[0].is<ScalarValue>(); | |||
auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0]; | |||
if (is_scalar) { | |||
return {ScalarValue::make(output)}; | |||
} else { | |||
return {output}; | |||
return imperative::apply(get_var_shape, inputs); | |||
} | |||
} | |||
std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) { | |||
ValueRefList reshape_rule( | |||
const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
mgb_assert(inputs.size() == 2); | |||
bool is_scalar = is_scalar_shape(inputs[1]); | |||
auto unwrapped_input = inputs[0].is<ScalarValue>() | |||
? inputs[0].cast<ScalarValue>().value() | |||
: inputs[0]; | |||
if (is_scalar) { | |||
return {ScalarValue::make(imperative::apply( | |||
reshape, unwrapped_input, | |||
make_scalar_shape(*unwrapped_input.device()))[0])}; | |||
reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
} else { | |||
return imperative::apply(reshape, unwrap_inputs(inputs)); | |||
return imperative::apply(reshape, inputs); | |||
} | |||
} | |||
std::vector<ValueRef> broadcast_rule( | |||
const Broadcast& broadcast, Span<ValueRef> inputs) { | |||
ValueRefList broadcast_rule( | |||
const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
mgb_assert(inputs.size() == 2); | |||
bool is_scalar = is_scalar_shape(inputs[1]); | |||
auto unwrapped_input = inputs[0].is<ScalarValue>() | |||
? inputs[0].cast<ScalarValue>().value() | |||
: inputs[0]; | |||
if (is_scalar) { | |||
return {ScalarValue::make(imperative::apply( | |||
broadcast, unwrapped_input, | |||
make_scalar_shape(*unwrapped_input.device()))[0])}; | |||
} else { | |||
return imperative::apply(broadcast, unwrap_inputs(inputs)); | |||
} | |||
} | |||
std::vector<ValueRef> copy_rule(const Copy& copy, Span<ValueRef> inputs) { | |||
mgb_assert(inputs.size() == 1); | |||
bool is_scalar = inputs[0].is<ScalarValue>(); | |||
if (is_scalar) { | |||
return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])}; | |||
} else { | |||
return imperative::apply(copy, unwrap_inputs(inputs)); | |||
} | |||
} | |||
std::vector<ValueRef> inplace_add_rule( | |||
const InplaceAdd& inplace_add, Span<ValueRef> inputs) { | |||
mgb_assert(inputs.size() == 4); | |||
bool is_scalar = inputs[0].is<ScalarValue>(); | |||
if (is_scalar) { | |||
return {ScalarValue::make( | |||
imperative::apply(inplace_add, unwrap_inputs(inputs))[0])}; | |||
broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
} else { | |||
return imperative::apply(inplace_add, unwrap_inputs(inputs)); | |||
return imperative::apply(broadcast, inputs); | |||
} | |||
} | |||
template <typename T> | |||
std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) { | |||
ValueRefList subgraph_op_rule( | |||
const T& op, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
const Type<ScalarValue>& scalar_type) { | |||
// TODO: add flag instead of assume | |||
bool all_scalar = true; | |||
for (auto&& input : inputs) { | |||
if (!input.is<ScalarValue>()) { | |||
for (auto&& input_mask : inputs_mask) { | |||
if (!input_mask) { | |||
all_scalar = false; | |||
} | |||
} | |||
auto outputs = imperative::apply(op, unwrap_inputs(inputs)); | |||
auto outputs = imperative::apply(op, inputs); | |||
if (all_scalar) { | |||
for (auto& output : outputs) { | |||
output = ScalarValue::make(output); | |||
output = scalar_type.make(output); | |||
} | |||
} | |||
return outputs; | |||
@@ -341,67 +271,54 @@ std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) { | |||
struct ScalarRuleRegistry { | |||
ScalarRuleRegistry() { | |||
register_scalar_rule(elemwise_rule); | |||
register_scalar_rule(remove_axis_rule); | |||
register_scalar_rule(reduce_rule); | |||
register_scalar_rule(typecvt_rule); | |||
register_scalar_rule(collective_comm_rule); | |||
register_scalar_rule(param_pack_split_rule); | |||
register_scalar_rule(dot_rule); | |||
register_scalar_rule(add_axis_rule); | |||
register_scalar_rule(remote_recv_rule); | |||
register_scalar_rule(check_no_finite_rule); | |||
register_scalar_rule(subtensor_rule); | |||
register_scalar_rule(get_var_shape_rule); | |||
register_scalar_rule(fastpath_copy_rule); | |||
register_scalar_rule(reshape_rule); | |||
register_scalar_rule(broadcast_rule); | |||
register_scalar_rule(copy_rule); | |||
register_scalar_rule(inplace_add_rule); | |||
register_scalar_rule(subgraph_op_rule<SubgraphOp>); | |||
register_scalar_rule(subgraph_op_rule<CompiledOp>); | |||
register_scalar_rule<Elemwise, elemwise_rule<Elemwise, 0>>(); | |||
register_scalar_rule<RemoveAxis, remove_axis_rule>(); | |||
register_scalar_rule<Reduce, reduce_rule>(); | |||
register_scalar_rule<TypeCvt, elemwise_rule<TypeCvt, 1>>(); | |||
register_scalar_rule<CollectiveComm, collective_comm_rule>(); | |||
register_scalar_rule<ParamPackSplit, param_pack_split_rule>(); | |||
register_scalar_rule<Dot, dot_rule>(); | |||
register_scalar_rule<AddAxis, add_axis_rule>(); | |||
register_scalar_rule<RemoteRecv, remote_recv_rule>(); | |||
register_scalar_rule<CheckNonFinite, check_no_finite_rule>(); | |||
register_scalar_rule<Subtensor, subtensor_rule>(); | |||
register_scalar_rule<GetVarShape, get_var_shape_rule>(); | |||
register_scalar_rule<FastpathCopy, elemwise_rule<FastpathCopy, 1>>(); | |||
register_scalar_rule<Reshape, reshape_rule>(); | |||
register_scalar_rule<Broadcast, broadcast_rule>(); | |||
register_scalar_rule<Copy, elemwise_rule<Copy, 1>>(); | |||
register_scalar_rule<InplaceAdd, elemwise_rule<InplaceAdd, 4>>(); | |||
register_scalar_rule<SubgraphOp, subgraph_op_rule<SubgraphOp>>(); | |||
register_scalar_rule<CompiledOp, subgraph_op_rule<CompiledOp>>(); | |||
} | |||
} _; | |||
} // namespace | |||
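For orientation, the registry above now names the op type explicitly and the rule bodies receive inputs that the transformation has already unwrapped, together with a per-input mask recording which ones were scalars. A minimal sketch of a rule under the new signature follows; `MyOp` and `my_op_rule` are placeholders for illustration, not part of this change.

```cpp
// Illustrative sketch only: MyOp / my_op_rule are hypothetical names.
// Inputs arrive already unwrapped; inputs_mask[i] tells whether input i was a
// ScalarValue, and the rule re-wraps outputs that should stay scalar.
ValueRefList my_op_rule(
        const MyOp& op, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    bool all_scalar = true;
    for (auto&& mask : inputs_mask) {
        if (!mask) {
            all_scalar = false;
        }
    }
    auto outputs = imperative::apply(op, inputs);
    if (all_scalar) {
        for (auto& output : outputs) {
            output = ScalarValue::make(output);
        }
    }
    return outputs;
}
// Registered against the op type, mirroring the entries above:
//     register_scalar_rule<MyOp, my_op_rule>();
```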
std::vector<ValueRef> ScalarTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto apply_op = op.as<ApplyOp>()) { | |||
auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); | |||
if (iter != scalar_rules.end()) { | |||
return iter->second(apply_op->op(), inputs); | |||
} else { | |||
// TODO: repeat op | |||
return imperative::apply(op, unwrap_inputs(inputs)); | |||
} | |||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
if (create_tensor->shape().is_scalar()) { | |||
ValueShape scalar_shape = {1}; | |||
CreateTensor scalar_op( | |||
create_tensor->kind(), create_tensor->device(), | |||
create_tensor->dtype(), scalar_shape); | |||
return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; | |||
} else { | |||
return imperative::apply(op, inputs); | |||
} | |||
} else if (auto* get_attr = op.as<GetAttr>()) { | |||
bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>(); | |||
auto output = imperative::apply(op, unwrap_inputs(inputs))[0]; | |||
if (!is_scalar) { | |||
return {output}; | |||
ValueRefList ScalarTransformation::apply_get_attr( | |||
const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
auto&& input = inputs.item(); | |||
bool is_scalar = input.is<ScalarValue>(); | |||
if (!is_scalar) { | |||
return imperative::apply(get_attr, input); | |||
} | |||
auto unwrapped_input = input.cast<ScalarValue>().value(); | |||
if (get_attr.attr() == GetAttr::Shape) { | |||
if (!m_empty_shape) { | |||
m_empty_shape = ShapeValue::make(); | |||
} | |||
switch (get_attr->attr()) { | |||
case GetAttr::Shape: { | |||
// Scalar Shape | |||
return {ShapeValue::make()}; | |||
} | |||
return {m_empty_shape}; | |||
} else { | |||
auto outputs = imperative::apply(get_attr, unwrapped_input); | |||
auto& output = outputs[0]; | |||
switch (get_attr.attr()) { | |||
case GetAttr::Value: { | |||
auto& hv = output.cast<HostValue>(); | |||
mgb_assert( | |||
hv.shape() == ValueShape({1}), | |||
"underlying value should has shape {1}, got %s", | |||
hv.shape().to_string().c_str()); | |||
return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())}; | |||
output = HostValue::make(hv.dtype(), ValueShape(), hv.storage()); | |||
break; | |||
} | |||
case GetAttr::Data: { | |||
auto& dv = output.cast<DeviceValue>(); | |||
@@ -409,22 +326,67 @@ std::vector<ValueRef> ScalarTransformation::apply_transformation( | |||
dv.shape() == ValueShape({1}), | |||
"underlying value should has shape {1}, got %s", | |||
dv.shape().to_string().c_str()); | |||
return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())}; | |||
output = DeviceValue::make(dv.dtype(), ValueShape(), dv.storage()); | |||
break; | |||
} | |||
default: | |||
return {output}; | |||
break; | |||
} | |||
return outputs; | |||
} | |||
} | |||
ValueRefList ScalarTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* get_attr = op.as<GetAttr>()) { | |||
// fastpath for GetAttr | |||
return apply_get_attr(*get_attr, inputs); | |||
} | |||
size_t nr_inputs = inputs.size(); | |||
ValueRefList unwrapped_inputs(nr_inputs); | |||
bool inputs_mask[nr_inputs]; | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) { | |||
unwrapped_inputs[i] = scalar_value->value(); | |||
inputs_mask[i] = true; | |||
} else { | |||
unwrapped_inputs[i] = inputs[i]; | |||
inputs_mask[i] = false; | |||
} | |||
} | |||
auto fallback = [&] { return imperative::apply(op, unwrapped_inputs); }; | |||
if (auto apply_op = op.as<ApplyOp>()) { | |||
auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); | |||
if (iter != scalar_rules.end()) { | |||
return iter->second( | |||
apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs}); | |||
} else { | |||
// TODO: repeat op | |||
return fallback(); | |||
} | |||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
if (create_tensor->shape().is_scalar()) { | |||
ValueShape scalar_shape = {1}; | |||
CreateTensor scalar_op( | |||
create_tensor->kind(), create_tensor->device(), | |||
create_tensor->dtype(), scalar_shape); | |||
return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; | |||
} else { | |||
return imperative::apply(op, inputs); | |||
} | |||
} else if (op.as<IsScalar>()) { | |||
return {BoolValue::make(inputs.as_array<1>()[0].is<ScalarValue>())}; | |||
mgb_assert(nr_inputs == 1); | |||
return {BoolValue::make(inputs_mask[0])}; | |||
} else if (op.is<Operator::IdentityLike>()) { | |||
bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>(); | |||
mgb_assert(nr_inputs == 1); | |||
bool is_scalar = inputs_mask[0]; | |||
auto outputs = fallback(); | |||
if (is_scalar) { | |||
return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])}; | |||
} else { | |||
return imperative::apply(op, inputs); | |||
outputs[0] = ScalarValue::make(outputs[0]); | |||
} | |||
return outputs; | |||
} else { | |||
return imperative::apply(op, unwrap_inputs(inputs)); | |||
return fallback(); | |||
} | |||
}; | |||
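From the caller's side the net effect of this rewrite is unchanged: scalar tensors are stored as rank-1 values wrapped in ScalarValue, and attribute reads translate back to rank-0 views. A small sketch of the observable behaviour, assuming `x` is a ValueRef created with a scalar shape while this transformation is on the dispatch stack:

```cpp
// Hedged sketch of the caller-visible semantics (x assumed scalar, see above).
bool scalar = x.is_scalar();  // answered by the IsScalar branch via inputs_mask[0]
auto shape = x.shape();       // GetAttr::Shape fastpath returns the cached empty ShapeValue
auto host = x.numpy();        // GetAttr::Value rebuilds a rank-0 HostValue from the {1} storage
```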
@@ -0,0 +1,25 @@ | |||
/** | |||
* \file imperative/src/impl/transformations/tangent.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/transformations/tangent.h" | |||
namespace mgb { | |||
namespace imperative { | |||
ValueRefList TangentTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* apply_op = op.as<ApplyOp>()) { | |||
} | |||
mgb_assert(false); | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
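The new tangent.cpp is only a stub: the ApplyOp branch is empty and the function ends in an assertion. The header it includes pairs each value with a tangent (TangentInfo), which is the usual forward-mode setup. As a standalone illustration of that idea, not code from this change, dual numbers carry a value and its derivative through every operation:

```cpp
// Standalone sketch of forward-mode differentiation (dual numbers); it only
// illustrates what a (value, tangent) pair buys, nothing MegEngine-specific.
#include <cstdio>

struct Dual {
    double value;
    double tangent;  // derivative with respect to the seeded input
};

Dual mul(Dual a, Dual b) {
    // product rule: (ab)' = a'b + ab'
    return {a.value * b.value, a.tangent * b.value + a.value * b.tangent};
}

int main() {
    Dual x{3.0, 1.0};    // seed dx/dx = 1
    Dual y = mul(x, x);  // y = x^2, so dy/dx = 2x
    std::printf("y = %g, dy/dx = %g\n", y.value, y.tangent);  // y = 9, dy/dx = 6
    return 0;
}
```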
@@ -153,7 +153,7 @@ VarNodeArray TraceResult::dump( | |||
return output_nodes; | |||
} | |||
std::vector<ValueRef> TracingTransformation::apply_transformation( | |||
ValueRefList TracingTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* op_value = op.as<ApplyOp>()) { | |||
SmallVector<ValueRef> unwrapped_inputs; | |||
@@ -180,11 +180,12 @@ std::vector<ValueRef> TracingTransformation::apply_transformation( | |||
} | |||
const_cast<OpDef&>(op_value->op()).set_scope(scopes_join); | |||
auto unwrapped_outputs = imperative::apply(op, unwrapped_inputs); | |||
std::vector<ValueRef> wrapped_outputs; | |||
ValueRefList wrapped_outputs(unwrapped_outputs.size()); | |||
SmallVector<size_t> output_ids; | |||
for (auto&& output : unwrapped_outputs) { | |||
for (size_t i = 0; i < unwrapped_outputs.size(); ++i) { | |||
auto&& output = unwrapped_outputs[i]; | |||
auto wrapped_output = record_var(output, false, VarKind::Internal); | |||
wrapped_outputs.push_back(wrapped_output); | |||
wrapped_outputs[i] = wrapped_output; | |||
output_ids.push_back(wrapped_output->id()); | |||
} | |||
m_seq.push_back({op_value->op().shared_from_this(), input_ids, output_ids}); | |||
@@ -375,6 +376,11 @@ void CompiledTransformation::compile() { | |||
return accessor; | |||
}; | |||
std::vector<VarAccessor> var_accessors(m_vars.size()); | |||
auto exc_setter = std::bind( | |||
&CompiledTransformation::set_exception, this, std::placeholders::_1); | |||
for (auto&& accessor : var_accessors) { | |||
accessor.exc_setter = exc_setter; | |||
} | |||
for (auto&& item : m_seq) { | |||
bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo()); | |||
VarNodeArray input_vars; | |||
@@ -509,8 +515,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { | |||
} | |||
} | |||
TracedValue::ref_t CompiledTransformation::trace_output(size_t id) { | |||
auto traced_value = TracedValue::make(id); | |||
auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t { | |||
auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]); | |||
m_weak_values.push_back(traced_value); | |||
return traced_value; | |||
} | |||
@@ -520,64 +526,99 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() { | |||
return m_seq[m_pc++]; | |||
} | |||
std::vector<ValueRef> CompiledTransformation::apply_transformation( | |||
ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { | |||
if (!m_shape) { | |||
trace_assert(m_accessor->shape_getter, "shape unreadable"); | |||
m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter())); | |||
} | |||
return m_shape; | |||
} | |||
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { | |||
if (!m_dtype) { | |||
m_dtype = DTypeValue::make(m_var->dtype); | |||
} | |||
return m_dtype; | |||
} | |||
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { | |||
if (!m_comp_node) { | |||
m_comp_node = CompNodeValue::make(m_var->device); | |||
} | |||
return m_comp_node; | |||
} | |||
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { | |||
return *m_accessor; | |||
} | |||
ValueRefList CompiledTransformation::apply_op( | |||
const ApplyOp& apply_op, Span<ValueRef> inputs) { | |||
auto& item = next_instruction(); | |||
trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); | |||
trace_assert(apply_op.op().is_same(*item.op), "operator mismatch"); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
trace_input(item.inputs[i], inputs[i]); | |||
} | |||
ValueRefList outputs(item.outputs.size()); | |||
for (size_t i = 0; i < item.outputs.size(); ++i) { | |||
outputs[i] = trace_output(item.outputs[i]); | |||
} | |||
return outputs; | |||
} | |||
ValueRefList CompiledTransformation::apply_get_attr( | |||
const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
if (auto* traced_value = inputs[0].as<TracedValue>()) { | |||
ValueRef output; | |||
auto& var_accessor = traced_value->accessor(); | |||
switch (get_attr.attr()) { | |||
case GetAttr::Shape: | |||
output = traced_value->shape(); | |||
break; | |||
case GetAttr::Data: | |||
trace_assert(var_accessor.data_getter, "data unreadable"); | |||
output = DeviceValue::make(var_accessor.data_getter()); | |||
break; | |||
case GetAttr::Value: | |||
trace_assert(var_accessor.value_getter, "value unreadable"); | |||
output = HostValue::make(var_accessor.value_getter()); | |||
break; | |||
case GetAttr::DType: | |||
output = traced_value->dtype(); | |||
break; | |||
case GetAttr::Device: | |||
                output = traced_value->comp_node(); | |||
                break; | |||
default: | |||
break; | |||
} | |||
return {output}; | |||
} else { | |||
return imperative::apply(get_attr, inputs); | |||
} | |||
} | |||
ValueRefList CompiledTransformation::apply_create_tensor( | |||
const CreateTensor& create_tensor, Span<ValueRef> inputs) { | |||
if (create_tensor.kind() == CreateTensor::NoTrace) { | |||
return imperative::apply(create_tensor, inputs); | |||
} | |||
auto& item = next_instruction(); | |||
trace_assert(item.op == nullptr, "operator mismatch"); | |||
auto input_id = item.inputs[0]; | |||
auto output_id = item.outputs[0]; | |||
auto tensor = imperative::apply(create_tensor, inputs)[0]; | |||
trace_input(input_id, tensor); | |||
return {trace_output(output_id)}; | |||
} | |||
ValueRefList CompiledTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* op_value = op.as<ApplyOp>()) { | |||
auto& item = next_instruction(); | |||
SmallVector<ValueRef> unwrapped_inputs; | |||
SmallVector<ValueRef> wrapped_inputs; | |||
trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); | |||
trace_assert(op_value->op().is_same(*item.op), "operator mismatch"); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
trace_input(item.inputs[i], inputs[i]); | |||
} | |||
std::vector<ValueRef> outputs; | |||
for (auto&& output_id : item.outputs) { | |||
outputs.push_back(trace_output(output_id)); | |||
} | |||
return outputs; | |||
return apply_op(*op_value, inputs); | |||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
if (create_tensor->kind() == CreateTensor::NoTrace) { | |||
return imperative::apply(op, inputs); | |||
} | |||
auto& item = next_instruction(); | |||
trace_assert(item.op == nullptr, "operator mismatch"); | |||
auto input_id = item.inputs[0]; | |||
auto output_id = item.outputs[0]; | |||
auto tensor = imperative::apply(op, inputs)[0]; | |||
trace_input(input_id, tensor); | |||
return {trace_output(output_id)}; | |||
return apply_create_tensor(*create_tensor, inputs); | |||
} else if (auto* get_attr = op.as<GetAttr>()) { | |||
if (auto* traced_value = inputs[0].as<TracedValue>()) { | |||
ValueRef output; | |||
auto& var = m_vars[traced_value->id()]; | |||
auto& var_accessor = m_var_accessors[traced_value->id()]; | |||
switch (get_attr->attr()) { | |||
case GetAttr::Shape: | |||
trace_assert(var_accessor.shape_getter, "shape unreadable"); | |||
output = ShapeValue::make( | |||
ValueShape::from(var_accessor.shape_getter())); | |||
break; | |||
case GetAttr::Data: | |||
trace_assert(var_accessor.data_getter, "data unreadable"); | |||
output = DeviceValue::make(var_accessor.data_getter()); | |||
break; | |||
case GetAttr::Value: | |||
trace_assert(var_accessor.value_getter, "value unreadable"); | |||
output = HostValue::make(var_accessor.value_getter()); | |||
break; | |||
case GetAttr::DType: | |||
output = DTypeValue::make(var.dtype); | |||
break; | |||
case GetAttr::Device: | |||
output = CompNodeValue::make(var.device); | |||
default: | |||
break; | |||
} | |||
return {output}; | |||
} else { | |||
return imperative::apply(op, inputs); | |||
} | |||
return apply_get_attr(*get_attr, inputs); | |||
} else if (auto* trace_mark_var = op.as<TraceMarkVar>()) { | |||
auto& item = next_instruction(); | |||
trace_assert(item.op == nullptr, "operator mismatch"); | |||
@@ -8,50 +8,58 @@ namespace mgb { | |||
namespace imperative { | |||
namespace { | |||
static thread_local size_t nr_watched_values = 0; | |||
static thread_local uint64_t nr_values = 0; | |||
static thread_local bool recording_values = false; | |||
static thread_local std::vector<ValueWeakRef> recorded_values; | |||
static /*thread_local*/ size_t nr_watched_values = 0; | |||
static /*thread_local*/ uint64_t nr_values = 0; | |||
static /*thread_local*/ bool recording_values = false; | |||
static /*thread_local*/ std::vector<ValueWeakRef> recorded_values; | |||
static WeakValueMap<uint64_t, ValueWeakRef> registered_values; | |||
} // namespace | |||
ValueRef::storage_t& ValueRef::storage() const { | |||
if (!m_storage) { | |||
if (mgb_likely(!m_storage->m_successor.m_storage)) { | |||
return m_storage; | |||
} | |||
if (auto& storage = m_storage->m_successor.m_storage) { | |||
while (storage->m_successor.m_storage) { | |||
storage = storage->m_successor.m_storage; | |||
} | |||
return storage; | |||
} else { | |||
return m_storage; | |||
while (m_storage->m_successor.m_storage) { | |||
m_storage = m_storage->m_successor.m_storage; | |||
} | |||
return m_storage; | |||
} | |||
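The rewritten storage() performs path compression on the reference itself: when a value has been superseded, the first dereference walks the successor chain once and re-points the (mutable) m_storage at the final node, so later accesses are constant time. A standalone sketch of the same pattern, with hypothetical names:

```cpp
// Standalone sketch of the path-compression idea (Node/resolve are hypothetical).
#include <memory>

struct Node {
    std::shared_ptr<Node> successor;  // set when this node is superseded
    int payload = 0;
};

// Re-point the caller's reference at the end of the chain, like mutable m_storage.
std::shared_ptr<Node>& resolve(std::shared_ptr<Node>& ref) {
    while (ref->successor) {
        ref = ref->successor;  // safe: shared_ptr assignment copies before releasing
    }
    return ref;
}
```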
const Value* ValueRef::as(size_t typecode) const { | |||
auto&& storage = this->storage(); | |||
if (storage->m_typecode != typecode) { | |||
return nullptr; | |||
} | |||
return static_cast<Value*>(storage.get()); | |||
} | |||
bool ValueRef::is(size_t typecode) const { | |||
return this->storage()->m_typecode == typecode; | |||
} | |||
TypedValueRef<DeviceValue> ValueRef::dev_tensor() const { | |||
return imperative::apply(GetAttr(GetAttr::Data), *this)[0].as_ref<DeviceValue>(); | |||
return imperative::apply(GetAttr(GetAttr::Data), *this)[0].cast_ref<DeviceValue>(); | |||
} | |||
TypedValueRef<HostValue> ValueRef::numpy() const { | |||
return imperative::apply(GetAttr(GetAttr::Value), *this)[0].as_ref<HostValue>(); | |||
return imperative::apply(GetAttr(GetAttr::Value), *this)[0].cast_ref<HostValue>(); | |||
} | |||
TypedValueRef<CompNodeValue> ValueRef::device() const { | |||
return imperative::apply(GetAttr(GetAttr::Device), *this)[0] | |||
.as_ref<CompNodeValue>(); | |||
.cast_ref<CompNodeValue>(); | |||
} | |||
TypedValueRef<ShapeValue> ValueRef::shape() const { | |||
return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].as_ref<ShapeValue>(); | |||
return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].cast_ref<ShapeValue>(); | |||
} | |||
TypedValueRef<DTypeValue> ValueRef::dtype() const { | |||
return imperative::apply(GetAttr(GetAttr::DType), *this)[0].as_ref<DTypeValue>(); | |||
return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>(); | |||
} | |||
TypedValueRef<StringValue> ValueRef::name() const { | |||
return imperative::apply(GetName(), *this)[0].as_ref<StringValue>(); | |||
return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>(); | |||
} | |||
bool ValueRef::is_scalar() const { | |||
@@ -75,13 +83,15 @@ void ValueRef::unwatch() const { | |||
} | |||
ValueRef ValueRef::unwrap() const { | |||
ValueRef value = *this; | |||
auto& context = Transformation::get_context(); | |||
for (size_t i = 0; i < context.next_transformation; ++i) { | |||
value = context.transformations[i]->unwrap(value); | |||
if (mgb_unlikely(context.next_transformation)) { | |||
ValueRef value = *this; | |||
for (size_t i = 0; i < context.next_transformation; ++i) { | |||
value = context.transformations[i]->unwrap(value); | |||
} | |||
return value; | |||
} | |||
mgb_assert(value); | |||
return value; | |||
return *this; | |||
} | |||
std::string ValueRef::to_string() const { | |||
@@ -101,13 +111,11 @@ std::string ValueRef::raw_type() const { | |||
return types[m_storage->m_typecode].name(); | |||
} | |||
uint64_t ValueRef::id() const { | |||
return m_storage ? m_storage->m_id : std::numeric_limits<uint64_t>::max(); | |||
} | |||
bool ValueRef::watching() const { | |||
auto storage = this->storage(); | |||
return storage && storage->m_watching; | |||
if (!m_storage) { | |||
return false; | |||
} | |||
return this->storage()->m_watching; | |||
} | |||
ValueRef ValueRef::make(ValueRef::storage_t storage) { | |||
@@ -186,5 +194,96 @@ void Value::try_rethrow() { | |||
} | |||
} | |||
inline void ValueRefList::init(size_t nr_elems) { | |||
m_size = nr_elems; | |||
if (m_size > 0) { | |||
if (m_size == 1) { | |||
m_data = inline_storage(); | |||
} else { | |||
auto& context = Transformation::get_context(); | |||
m_data = context.allocator.allocate(m_size); | |||
} | |||
for (size_t i = 0; i < m_size; ++i) { | |||
new (m_data + i) ValueRef(); | |||
} | |||
} else { | |||
m_data = nullptr; | |||
} | |||
} | |||
ValueRefList::ValueRefList(size_t nr_elems) { | |||
init(nr_elems); | |||
} | |||
ValueRefList::ValueRefList(std::initializer_list<ValueRef> values) | |||
: ValueRefList(values.begin(), values.end()) {} | |||
ValueRefList::ValueRefList(const ValueRefList& rhs) | |||
: ValueRefList(rhs.cbegin(), rhs.cend()) {} | |||
ValueRefList::ValueRefList(ValueRefList&& rhs) : ValueRefList() { | |||
m_size = rhs.m_size; | |||
if (rhs.m_data == rhs.inline_storage()) { | |||
m_data = inline_storage(); | |||
new (m_data) ValueRef(); | |||
m_data[0] = std::move(rhs.m_data[0]); | |||
} else { | |||
m_data = rhs.m_data; | |||
rhs.m_data = nullptr; | |||
rhs.m_size = 0; | |||
} | |||
} | |||
ValueRefList& ValueRefList::operator=(const ValueRefList& rhs) { | |||
if (this == &rhs) { | |||
return *this; | |||
} | |||
clear(); | |||
init(rhs.m_size); | |||
for (size_t i = 0; i < m_size; ++i) { | |||
m_data[i] = rhs.m_data[i]; | |||
} | |||
return *this; | |||
} | |||
ValueRefList& ValueRefList::operator=(ValueRefList&& rhs) { | |||
if (this == &rhs) { | |||
return *this; | |||
} | |||
clear(); | |||
if (rhs.m_data == rhs.inline_storage()) { | |||
m_data = inline_storage(); | |||
new (m_data) ValueRef(); | |||
m_data[0] = rhs.m_data[0]; | |||
m_size = 1; | |||
rhs.clear(); | |||
} else { | |||
m_data = rhs.m_data; | |||
m_size = rhs.m_size; | |||
rhs.m_data = nullptr; | |||
rhs.m_size = 0; | |||
} | |||
return *this; | |||
} | |||
ValueRefList::~ValueRefList() { | |||
clear(); | |||
} | |||
void ValueRefList::clear() { | |||
for (size_t i = 0; i < m_size; ++i) { | |||
m_data[i].~ValueRef(); | |||
} | |||
if (m_data) { | |||
if (m_size != 1) { | |||
Transformation::get_context().allocator.deallocate(m_data, m_size); | |||
} else { | |||
mgb_assert(m_data == inline_storage()); | |||
} | |||
} | |||
m_data = nullptr; | |||
m_size = 0; | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
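Taken together, ValueRefList keeps single results in the inline slot inside the object and takes larger buffers from the per-context ForwardAllocator; clear() runs the ValueRef destructors and hands multi-element buffers back. A short usage sketch, assuming the transformation context is set up as it is during dispatch:

```cpp
// Hedged usage sketch of the new container.
ValueRefList one(ValueRef{});          // size 1: lives in the inline slot, no allocation
ValueRefList many(3);                  // size 3: storage from context.allocator.allocate(3)
many[0] = one[0];                      // element access is plain pointer indexing
ValueRefList moved = std::move(many);  // >1 elements: the pointer is stolen, no copying
```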
@@ -24,8 +24,6 @@ namespace imperative { | |||
class GradKey; | |||
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
/** | |||
* \brief apply an OpDef to values | |||
* | |||
@@ -37,7 +35,7 @@ private: | |||
public: | |||
ApplyOp(const OpDef& op) : m_op(op) {} | |||
const OpDef& op() { return m_op; } | |||
const OpDef& op() const { return m_op; } | |||
std::string to_string() const override; | |||
}; | |||
@@ -106,7 +104,7 @@ public: | |||
* \param inputs contains host_storage and device_storage | |||
* \return Args unpacked args | |||
*/ | |||
Args parse(Span<ValueRef> inputs); | |||
Args parse(Span<ValueRef> inputs) const; | |||
Kind kind() const { return m_kind; } | |||
CompNode device() const { return m_device; } | |||
@@ -129,11 +127,11 @@ private: | |||
public: | |||
DTRCommand(Kind kind) : m_kind(kind) {} | |||
Kind kind() { return m_kind; } | |||
Kind kind() const { return m_kind; } | |||
std::string to_string() const override; | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { return {}; } | |||
ValueRefList fallback(Span<ValueRef> inputs) const override { return {}; } | |||
}; | |||
// deprecated | |||
@@ -141,9 +139,7 @@ class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> { | |||
public: | |||
std::string to_string() const override; | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
return {ValueRef()}; | |||
} | |||
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; } | |||
}; | |||
/** | |||
@@ -161,7 +157,7 @@ public: | |||
std::string to_string() const override; | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
ValueRefList fallback(Span<ValueRef> inputs) const override { | |||
return {inputs.as_array<1>()[0]}; | |||
} | |||
}; | |||
@@ -23,7 +23,7 @@ namespace imperative { | |||
class GradKey; | |||
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>; | |||
class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> { | |||
public: | |||
@@ -97,6 +97,10 @@ public: | |||
ValueShape shape() const { return m_shape; } | |||
CompNode device() const { return m_storage.comp_node(); } | |||
HostTensorStorage storage() const { return m_storage; } | |||
DTypeScalar item() const { | |||
mgb_assert(m_shape.is_scalar()); | |||
return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr()); | |||
} | |||
HostTensorND as_nd(bool allow_scalar = false) const; | |||
}; | |||
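The new item() accessor reads the single element of a scalar HostValue directly as a DTypeScalar, avoiding a round trip through HostTensorND. A hedged usage sketch; `hv` is a hypothetical HostValue known to be scalar, and the get_cast call assumes the usual DTypeScalar interface:

```cpp
// Sketch only: hv is assumed to hold a scalar host value.
auto scalar = hv.item();             // asserts hv.shape().is_scalar(), wraps dtype + raw bits
float f = scalar.get_cast<float>();  // assumption: DTypeScalar::get_cast<>() as in megbrain
```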
@@ -36,11 +36,11 @@ namespace imperative { | |||
* | |||
* \param op | |||
* \param inputs | |||
* \return std::vector<ValueRef> | |||
* \return ValueRefList | |||
*/ | |||
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs); | |||
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs); | |||
ValueRefList apply(const Operator& op, Span<ValueRef> inputs); | |||
ValueRefList apply(const OpDef& def, Span<ValueRef> inputs); | |||
ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs); | |||
template <typename... TArgs> | |||
constexpr bool is_all_value_ref_v = | |||
@@ -49,7 +49,7 @@ constexpr bool is_all_value_ref_v = | |||
template <typename T, typename... TArgs> | |||
static auto apply(T&& op, TArgs&&... args) | |||
-> std::enable_if_t<is_all_value_ref_v<TArgs...>, std::vector<ValueRef>> { | |||
-> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> { | |||
ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...}; | |||
return imperative::apply( | |||
std::forward<T&&>(op), | |||
@@ -63,7 +63,7 @@ static auto apply(T&& op, TContainer&& container) -> std::enable_if_t< | |||
ValueRef> && | |||
std::is_same_v<decltype(container.size()), size_t> && | |||
!std::is_same_v<std::decay_t<TContainer>, Span<ValueRef>>, | |||
std::vector<ValueRef>> { | |||
ValueRefList> { | |||
return imperative::apply( | |||
std::forward<T&&>(op), Span<ValueRef>(container.data(), container.size())); | |||
} | |||
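Both convenience overloads funnel into the Span<ValueRef> form, so call sites can pass loose ValueRefs or any contiguous container and now get a ValueRefList back. A hedged usage sketch, with `def`, `x` and `y` assumed to be an OpDef and two ValueRefs:

```cpp
// Sketch of the two call styles; names are assumptions for illustration.
ValueRefList out1 = imperative::apply(ApplyOp(def), x, y);   // variadic ValueRefs
SmallVector<ValueRef> args = {x, y};
ValueRefList out2 = imperative::apply(ApplyOp(def), args);   // container with data()/size()
```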
@@ -25,6 +25,8 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>; | |||
/** | |||
* \brief base class for all operators | |||
* | |||
@@ -49,25 +51,24 @@ public: | |||
Kind kind() const { return m_kind; } | |||
template <typename U> | |||
U* as() const { | |||
const U* as() const { | |||
if (m_typecode != U::TYPE_CODE) { | |||
return nullptr; | |||
} | |||
return static_cast<U*>(const_cast<Operator*>(this)); | |||
return static_cast<const U*>(this); | |||
} | |||
template <typename U> | |||
bool is() const { | |||
return as<U>() != nullptr; | |||
return m_typecode == U::TYPE_CODE; | |||
} | |||
template <Kind kKind> | |||
bool is() const { | |||
return kind() == kKind; | |||
} | |||
template <typename U> | |||
U& cast() const { | |||
U* ptr = as<U>(); | |||
mgb_assert(ptr); | |||
return *ptr; | |||
const U& cast() const { | |||
mgb_assert(m_typecode == U::TYPE_CODE); | |||
return static_cast<const U&>(*this); | |||
} | |||
virtual std::string to_string() const = 0; | |||
@@ -77,9 +78,9 @@ public: | |||
* implementation. | |||
* | |||
* \param inputs | |||
* \return std::vector<ValueRef> | |||
* \return ValueRefList | |||
*/ | |||
virtual std::vector<ValueRef> fallback(Span<ValueRef> inputs) const; | |||
virtual ValueRefList fallback(Span<ValueRef> inputs) const; | |||
std::type_index type() const { return registered_types()[m_typecode]; } | |||
@@ -123,7 +123,6 @@ public: | |||
template <typename T, typename... TArgs> | |||
static uint64_t record(TArgs&&... args) { | |||
auto& profiler = get_instance(); | |||
// auto& mem_pool = get_mem_pool<T>(); | |||
if constexpr (sm_debug) { | |||
Status expected = Running; | |||
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); | |||
@@ -18,6 +18,7 @@ | |||
#include "megbrain/common.h" | |||
#include "megbrain/imperative/subgraph.h" | |||
#include "megbrain/imperative/utils/allocator.h" | |||
#include "megbrain/imperative/utils/local_ptr.h" | |||
#include "megbrain/imperative/utils/span.h" | |||
@@ -25,6 +26,7 @@ namespace mgb { | |||
namespace imperative { | |||
class ValueRef; | |||
class ValueRefList; | |||
class Operator; | |||
class Transformation; | |||
@@ -43,6 +45,7 @@ struct TransformationContext { | |||
// TODO: deprecate TransformationGuard, let next_transformation == frames.size() | |||
size_t next_transformation = 0; | |||
std::vector<TransformationFrame> frames; | |||
ForwardAllocator<ValueRef> allocator; | |||
}; | |||
/** | |||
@@ -86,9 +89,9 @@ public: | |||
* | |||
* \param op | |||
* \param inputs | |||
* \return std::vector<ValueRef> | |||
* \return ValueRefList | |||
*/ | |||
virtual std::vector<ValueRef> apply_transformation( | |||
virtual ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) = 0; | |||
virtual ValueRef unwrap(ValueRef value) = 0; | |||
@@ -187,11 +190,12 @@ public: | |||
std::swap(context.transformations, current_context.transformations); | |||
std::swap(context.scopes, current_context.scopes); | |||
std::swap(context.next_transformation, current_context.next_transformation); | |||
std::swap(context.allocator, current_context.allocator); | |||
} | |||
static TransformationContext& get_context(); | |||
friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs); | |||
friend class ValueRef; | |||
}; | |||
@@ -23,16 +23,38 @@ public: | |||
using Handle = interpreter::Interpreter::Handle; | |||
using Channel = interpreter::Interpreter::Channel; | |||
class RAIIHandle : public NonCopyableObj { | |||
private: | |||
Handle m_handle = nullptr; | |||
Channel* m_channel = nullptr; | |||
public: | |||
RAIIHandle(Handle handle, Channel* channel) | |||
: m_handle(handle), m_channel(channel) {} | |||
~RAIIHandle() { m_channel->del(m_handle); } | |||
Handle handle() const { return m_handle; } | |||
Channel* channel() const { return m_channel; } | |||
}; | |||
private: | |||
std::shared_ptr<Handle> m_handle = nullptr; | |||
LocalPtr<RAIIHandle> m_handle; | |||
std::string m_name; | |||
mutable DTypeValue::ref_t m_dtype; | |||
mutable CompNodeValue::ref_t m_comp_node; | |||
mutable ShapeValue::ref_t m_shape; | |||
public: | |||
InterpreterInfo() = default; | |||
InterpreterInfo(std::shared_ptr<Handle> handle, std::string name = {}) | |||
InterpreterInfo(LocalPtr<RAIIHandle> handle, std::string name = {}) | |||
: m_handle(handle), m_name(name) {} | |||
std::shared_ptr<Handle> handle() const { return m_handle; } | |||
const LocalPtr<RAIIHandle>& handle() const { return m_handle; } | |||
DTypeValue::ref_t dtype() const; | |||
CompNodeValue::ref_t comp_node() const; | |||
ShapeValue::ref_t shape() const; | |||
std::string name() const { return m_name; } | |||
}; | |||
@@ -60,6 +82,7 @@ class InterpreterTransformation final : public Transformation { | |||
public: | |||
using Interpreter = interpreter::Interpreter; | |||
using Handle = Interpreter::Handle; | |||
using SharedHandle = LocalPtr<InterpreterInfo::RAIIHandle>; | |||
using Channel = Interpreter::Channel; | |||
private: | |||
@@ -71,7 +94,14 @@ public: | |||
Channel* channel() { return m_channel.get(); } | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs); | |||
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); | |||
ValueRefList apply_create_tensor( | |||
const CreateTensor& create_tensor, Span<ValueRef> inputs); | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
ValueRef unwrap(ValueRef value) override { | |||
@@ -81,14 +111,8 @@ public: | |||
std::string name() const override { return "InterpreterTransformation"; } | |||
std::shared_ptr<Handle> share_handle(Handle handle) { | |||
return std::shared_ptr<Handle>( | |||
new Handle(handle), [channel = m_channel.get()](Handle* ptr) { | |||
if (ptr) { | |||
channel->del(*ptr); | |||
delete ptr; | |||
} | |||
}); | |||
SharedHandle share_handle(Handle handle) { | |||
return SharedHandle::make(handle, m_channel.get()); | |||
} | |||
}; | |||
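The shared_ptr<Handle>-with-custom-deleter is replaced by a pooled LocalPtr<RAIIHandle>: the channel pointer travels with the handle and ~RAIIHandle issues channel->del() when the last reference drops. A hedged sketch of the ownership flow, written as if inside InterpreterTransformation with a valid handle:

```cpp
// Sketch of the new ownership model (handle/channel assumed valid).
SharedHandle owner = share_handle(handle);  // LocalPtr<RAIIHandle> from the mem pool
SharedHandle alias = owner;                 // cheap ref-count bump, no extra heap node
owner.reset();                              // handle stays alive through `alias`
alias.reset();                              // last ref gone: ~RAIIHandle -> channel->del(handle)
```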
@@ -34,9 +34,7 @@ struct BackwardGraphWithClosure { | |||
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph, | |||
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs); | |||
void operator()( | |||
std::vector<ValueRef> grads, | |||
std::function<void(size_t, ValueRef)> receiver); | |||
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); | |||
bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } | |||
@@ -50,12 +48,11 @@ struct BackwardGraphWithClosure { | |||
struct CustomBackward; | |||
using GradRuleFn = | |||
std::function<std::vector<ValueRef>(Span<ValueRef> inputs, CustomBackward&)>; | |||
using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>; | |||
struct CustomBackward { | |||
using BackwardFn = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
using BackwardRule = std::function<std::optional<std::vector<ValueRef>>( | |||
using BackwardFn = std::function<ValueRefList(Span<ValueRef>)>; | |||
using BackwardRule = std::function<std::optional<ValueRefList>( | |||
const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>; | |||
BackwardFn m_backward; | |||
SmallVector<bool, 8> m_input_has_grad; | |||
@@ -65,9 +62,7 @@ struct CustomBackward { | |||
SmallVector<OutputAttr> m_output_attrs; | |||
public: | |||
void operator()( | |||
std::vector<ValueRef> grads, | |||
std::function<void(size_t, ValueRef)> receiver); | |||
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); | |||
bool input_has_grad(size_t i) { return m_input_has_grad[i]; } | |||
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } | |||
@@ -188,7 +183,7 @@ public: | |||
std::string to_string() const override; | |||
bool has_key(std::shared_ptr<GradKey> key) const { return m_key == key; } | |||
bool has_key(const std::shared_ptr<GradKey>& key) const { return m_key == key; } | |||
const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const { | |||
mgb_assert(m_key == key); | |||
@@ -287,7 +282,7 @@ public: | |||
return false; | |||
} | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
ValueRef unwrap(ValueRef value) override { | |||
@@ -314,7 +309,7 @@ private: | |||
public: | |||
std::string to_string() const override { return "DetachValue"; } | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
ValueRefList fallback(Span<ValueRef> inputs) const override { | |||
return {inputs.as_array<1>()[0]}; | |||
} | |||
}; | |||
@@ -325,7 +320,7 @@ private: | |||
public: | |||
AttachGrad(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
std::shared_ptr<GradKey> key() { return m_key; } | |||
std::shared_ptr<GradKey> key() const { return m_key; } | |||
std::string to_string() const override { | |||
return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str()); | |||
@@ -339,7 +334,7 @@ private: | |||
public: | |||
GradBackward(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
std::shared_ptr<GradKey> key() { return m_key; } | |||
std::shared_ptr<GradKey> key() const { return m_key; } | |||
std::string to_string() const override { | |||
return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str()); | |||
@@ -352,13 +347,13 @@ private: | |||
public: | |||
IsAttachedTo(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
std::shared_ptr<GradKey> key() { return m_key; } | |||
std::shared_ptr<GradKey> key() const { return m_key; } | |||
std::string to_string() const override { | |||
return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str()); | |||
} | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
ValueRefList fallback(Span<ValueRef> inputs) const override { | |||
return {BoolValue::make(false)}; | |||
} | |||
}; | |||
@@ -373,9 +368,9 @@ public: | |||
SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs) | |||
: m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} | |||
GenericFunction grad_fn() { return m_grad_fn; } | |||
GenericFunction grad_fn() const { return m_grad_fn; } | |||
size_t nr_inputs() { return m_nr_inputs; } | |||
size_t nr_inputs() const { return m_nr_inputs; } | |||
std::string to_string() const override { | |||
return ssprintf("SetGradValue{key=%s}", m_key->name().c_str()); | |||
@@ -388,9 +383,7 @@ public: | |||
std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); } | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
return {ValueRef()}; | |||
} | |||
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; } | |||
}; | |||
class GetBackwardColsure | |||
@@ -401,7 +394,7 @@ private: | |||
public: | |||
GetBackwardColsure(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
std::shared_ptr<GradKey> key() { return m_key; } | |||
std::shared_ptr<GradKey> key() const { return m_key; } | |||
std::string to_string() const override { | |||
return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str()); | |||
@@ -81,7 +81,7 @@ public: | |||
ComputingGraph::Options& options() { return m_graph->options(); } | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
ValueRef unwrap(ValueRef value) override { | |||
@@ -11,6 +11,7 @@ | |||
#pragma once | |||
#include "megbrain/imperative/basic_operators.h" | |||
#include "megbrain/imperative/dispatch.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
@@ -45,8 +46,10 @@ public: | |||
*/ | |||
class ScalarTransformation final : public Transformation { | |||
private: | |||
ShapeValue::ref_t m_empty_shape; // [] | |||
public: | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
ValueRef unwrap(ValueRef value) override { | |||
@@ -50,7 +50,7 @@ private: | |||
public: | |||
SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override { | |||
if (auto* apply_op = op.as<ApplyOp>()) { | |||
SmallVector<VarNode*> input_nodes; | |||
@@ -58,9 +58,9 @@ public: | |||
input_nodes.push_back(input.cast<SymbolValue>().node()); | |||
} | |||
auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); | |||
std::vector<ValueRef> outputs; | |||
for (auto&& output_node : output_nodes) { | |||
outputs.push_back(SymbolValue::make(output_node)); | |||
ValueRefList outputs(output_nodes.size()); | |||
for (size_t i = 0; i < output_nodes.size(); ++i) { | |||
outputs[i] = SymbolValue::make(output_nodes[i]); | |||
} | |||
return outputs; | |||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/grad.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/imperative/basic_operators.h" | |||
#include "megbrain/imperative/operator.h" | |||
#include "megbrain/imperative/transformation.h" | |||
#include "megbrain/imperative/value.h" | |||
namespace mgb::imperative { | |||
struct TangentInfo { | |||
ValueRef value; | |||
ValueRef tangent; | |||
}; | |||
class TangentTransformation final : public Transformation { | |||
public: | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
ValueRef unwrap(ValueRef value) override { mgb_assert(false); } | |||
std::string name() const override { return "Tangent"; } | |||
}; | |||
} // namespace mgb::imperative |
@@ -126,25 +126,6 @@ public: | |||
void on_unwatch() override { value().unwatch(); } | |||
}; | |||
class TracedInfo { | |||
private: | |||
size_t m_id = 0; | |||
public: | |||
TracedInfo() = default; | |||
TracedInfo(size_t id) : m_id(id) {} | |||
size_t id() const { return m_id; } | |||
}; | |||
class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override { | |||
return ssprintf("TracedValue{\"id\"=%zu}", id()); | |||
} | |||
}; | |||
/** | |||
* \brief trace operation sequence to TraceResult | |||
* | |||
@@ -202,7 +183,7 @@ public: | |||
return value; | |||
} | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
ValueRef unwrap(ValueRef value) override { | |||
@@ -248,6 +229,40 @@ public: | |||
std::function<DeviceTensorND()> data_getter; | |||
std::function<HostTensorND()> value_getter; | |||
std::function<void(DeviceTensorND)> data_setter; | |||
std::function<void(std::exception_ptr)> exc_setter; | |||
}; | |||
class TracedInfo { | |||
private: | |||
size_t m_id = 0; | |||
VarInfo* m_var = nullptr; | |||
VarAccessor* m_accessor = nullptr; | |||
mutable ShapeValue::ref_t m_shape; | |||
mutable DTypeValue::ref_t m_dtype; | |||
mutable CompNodeValue::ref_t m_comp_node; | |||
public: | |||
TracedInfo() = default; | |||
TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor) | |||
: m_id(id), m_var(var), m_accessor(accessor) {} | |||
size_t id() const { return m_id; } | |||
ShapeValue::ref_t shape() const; | |||
DTypeValue::ref_t dtype() const; | |||
CompNodeValue::ref_t comp_node() const; | |||
const VarAccessor& accessor() const; | |||
void set_exception(std::exception_ptr exc) const { | |||
m_accessor->exc_setter(exc); | |||
} | |||
}; | |||
class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override { | |||
return ssprintf("TracedValue{\"id\"=%zu}", id()); | |||
} | |||
}; | |||
private: | |||
@@ -319,7 +334,14 @@ public: | |||
TraceResult::SeqItem& next_instruction(); | |||
std::vector<ValueRef> apply_transformation( | |||
ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs); | |||
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); | |||
ValueRefList apply_create_tensor( | |||
const CreateTensor& create_tensor, Span<ValueRef> inputs); | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
void on_unregister() noexcept override; | |||
@@ -36,12 +36,12 @@ private: | |||
public: | |||
Allocator(pool_type* pool) : m_pool(pool) {} | |||
T* allocate(size_type n) { | |||
pointer allocate(size_type n) { | |||
mgb_assert(n == 1); | |||
return m_pool->alloc(sizeof(T)); | |||
} | |||
void deallocate(pointer* p, size_type n) { | |||
void deallocate(pointer p, size_type n) { | |||
mgb_assert(n == 1); | |||
m_pool->free(p); | |||
} | |||
@@ -68,4 +68,114 @@ public: | |||
bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; } | |||
}; | |||
} // namespace mgb::imperative | |||
template <typename T> | |||
class ForwardAllocator { | |||
public: | |||
using value_type = T; | |||
using size_type = std::size_t; | |||
using pointer = T*; | |||
static constexpr size_t alignment = alignof(T); | |||
static constexpr size_t element_offset = | |||
sizeof(T) + | |||
((sizeof(T) % alignment) ? 0 : (alignment - sizeof(T) % alignment)); | |||
private: | |||
struct Block { | |||
std::unique_ptr<std::byte[]> data; | |||
size_t size = 0; | |||
size_t capacity = 0; | |||
T* allocate(size_type n) { | |||
static_assert(element_offset > std::max(alignment, sizeof(T))); | |||
size_t begin = size; | |||
size_t end = begin + element_offset * n; | |||
if (end > capacity) { | |||
return nullptr; | |||
} | |||
size = end; | |||
return reinterpret_cast<T*>(data.get() + begin); | |||
} | |||
void reset() { size = 0; } | |||
}; | |||
std::vector<Block> m_used; | |||
std::optional<Block> m_current; | |||
size_t block_size = 16 * 1024 * 1024; | |||
size_t nr_allocated = 0; | |||
private: | |||
Block allocate_block() { | |||
block_size *= 2; | |||
return Block{std::make_unique<std::byte[]>(block_size), 0, block_size}; | |||
} | |||
public: | |||
pointer allocate(size_type n) { | |||
if (!m_current) { | |||
m_current.emplace(allocate_block()); | |||
} | |||
pointer pointer = m_current->allocate(n); | |||
while (pointer == nullptr) { | |||
m_used.push_back(allocate_block()); | |||
std::swap(m_used.back(), *m_current); | |||
pointer = m_current->allocate(n); | |||
} | |||
nr_allocated++; | |||
return pointer; | |||
} | |||
void deallocate(pointer p, size_type n) { | |||
mgb_assert(nr_allocated > 0); | |||
nr_allocated--; | |||
} | |||
void clear() { | |||
if (mgb_likely(m_used.empty())) { | |||
// fastpath | |||
if (m_current) { | |||
m_current->reset(); | |||
} | |||
} else { | |||
// trim | |||
*m_current = allocate_block(); | |||
m_used.clear(); | |||
} | |||
mgb_assert(nr_allocated == 0); | |||
} | |||
bool operator==(const ForwardAllocator& rhs) const { return &rhs == this; } | |||
bool operator!=(const ForwardAllocator& rhs) const { return &rhs != this; } | |||
}; | |||
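ForwardAllocator is a bump allocator: allocate() advances a cursor in the current block (blocks double in size as they fill), deallocate() only maintains the live count, and clear() recycles everything at once at a dispatch boundary. A minimal lifecycle sketch, mirroring how ValueRefList uses it:

```cpp
// Hedged lifecycle sketch; every allocation must be released before clear().
ForwardAllocator<ValueRef> alloc;
ValueRef* buf = alloc.allocate(4);  // raw, uninitialized slots for 4 refs
for (size_t i = 0; i < 4; ++i) {
    new (buf + i) ValueRef();       // placement-construct
}
for (size_t i = 0; i < 4; ++i) {
    buf[i].~ValueRef();             // destroy before giving the slots back
}
alloc.deallocate(buf, 4);           // bookkeeping only, memory is not freed here
alloc.clear();                      // O(1) bulk reset of the whole block
```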
template <typename T, template <typename> typename TAllocator> | |||
class ProxyAllocator { | |||
public: | |||
using value_type = T; | |||
using size_type = typename TAllocator<T>::size_type; | |||
using pointer = typename TAllocator<T>::pointer; | |||
private: | |||
TAllocator<T>* m_impl; | |||
public: | |||
T* allocate(size_type n) { return m_impl->allocate(n); } | |||
void deallocate(pointer* p, size_type n) { return m_impl->deallocate(p, n); } | |||
bool operator==(const ProxyAllocator<T, TAllocator>& rhs) const { | |||
if (m_impl == rhs.m_impl) { | |||
return true; | |||
} else if (bool(m_impl) ^ bool(rhs.m_impl)) { | |||
return false; | |||
} else { | |||
return *m_impl == *rhs.m_impl; | |||
} | |||
} | |||
bool operator!=(const ProxyAllocator<T, TAllocator>& rhs) const { | |||
return !((*this) == rhs); | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -16,6 +16,8 @@ | |||
#include "megbrain/imperative/utils/mempool.h" | |||
#include "megbrain/utils/metahelper.h" | |||
#define MGB_FAT_LOCAL_PTR 0 | |||
namespace mgb::imperative { | |||
template <typename T> | |||
@@ -52,6 +54,8 @@ private: | |||
} | |||
} | |||
size_t ref_count() const { return m_ref_count; } | |||
template <typename U> | |||
friend class LocalPtr; | |||
@@ -88,14 +92,24 @@ public: | |||
using storage_t = LocalPtrStorage<T>; | |||
using pool_t = MemPool<storage_t>; | |||
using weak_type = LocalWeakPtr<T>; | |||
using pointer_t = T*; | |||
private: | |||
storage_t* m_storage = nullptr; | |||
#if MGB_FAT_LOCAL_PTR | |||
pointer_t m_pointer = nullptr; | |||
#endif | |||
// (m_storage == nullptr) == (m_pointer == nullptr) | |||
void emplace(storage_t* ptr) { | |||
if (ptr) { | |||
ptr->inc_ref(); | |||
m_storage = ptr; | |||
#if MGB_FAT_LOCAL_PTR | |||
m_pointer = ptr->m_pointer; | |||
#endif | |||
} | |||
} | |||
@@ -103,8 +117,22 @@ private: | |||
public: | |||
LocalPtr() = default; | |||
LocalPtr(const LocalPtr& rhs) { (*this) = rhs; } | |||
LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); } | |||
LocalPtr(const LocalPtr& rhs) { | |||
auto storage = rhs.m_storage; | |||
if (storage) { | |||
storage->inc_ref(); | |||
} | |||
m_storage = storage; | |||
#if MGB_FAT_LOCAL_PTR | |||
m_pointer = rhs.m_pointer; | |||
#endif | |||
} | |||
LocalPtr(LocalPtr&& rhs) { | |||
std::swap(m_storage, rhs.m_storage); | |||
#if MGB_FAT_LOCAL_PTR | |||
std::swap(m_pointer, rhs.m_pointer); | |||
#endif | |||
} | |||
LocalPtr& operator=(const LocalPtr& rhs) { | |||
if (this == &rhs) { | |||
return *this; | |||
@@ -115,9 +143,11 @@ public: | |||
} | |||
if (m_storage) { | |||
m_storage->dec_ref(); | |||
// rhs.m_storage may be invalid here | |||
} | |||
m_storage = storage; | |||
#if MGB_FAT_LOCAL_PTR | |||
m_pointer = rhs.m_pointer; | |||
#endif | |||
return *this; | |||
} | |||
LocalPtr& operator=(LocalPtr&& rhs) { | |||
@@ -125,6 +155,9 @@ public: | |||
return *this; | |||
} | |||
std::swap(m_storage, rhs.m_storage); | |||
#if MGB_FAT_LOCAL_PTR | |||
std::swap(m_pointer, rhs.m_pointer); | |||
#endif | |||
rhs.reset(); | |||
return *this; | |||
} | |||
@@ -186,10 +219,11 @@ public: | |||
T& operator*() const { return *get(); } | |||
T* get() const { | |||
if ((!m_storage) || !m_storage->m_pointer) { | |||
return nullptr; | |||
} | |||
return m_storage->m_pointer; | |||
#if MGB_FAT_LOCAL_PTR | |||
return m_pointer; | |||
#else | |||
return m_storage ? m_storage->m_pointer : nullptr; | |||
#endif | |||
} | |||
T* operator->() const { return get(); } | |||
@@ -202,6 +236,9 @@ public: | |||
if (m_storage) { | |||
m_storage->dec_ref(); | |||
m_storage = nullptr; | |||
#if MGB_FAT_LOCAL_PTR | |||
m_pointer = nullptr; | |||
#endif | |||
} | |||
} | |||
@@ -49,8 +49,8 @@ public: | |||
instance = std::make_unique<MemPool<T>>(); | |||
sm_instance = instance.get(); | |||
} | |||
mgb_assert(sm_instance); | |||
} | |||
return *sm_instance; | |||
} | |||
}; | |||
@@ -62,9 +62,9 @@ std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||
MemPoolUtils<T>::sm_instances; | |||
template <typename T> | |||
thread_local MemPool<T>* MemPoolUtils<T>::tm_instance; | |||
thread_local MemPool<T>* MemPoolUtils<T>::tm_instance = nullptr; | |||
template <typename T> | |||
MemPool<T>* MemPoolUtils<T>::sm_instance; | |||
MemPool<T>* MemPoolUtils<T>::sm_instance = nullptr; | |||
} // namespace mgb::imperative | |||
} // namespace mgb::imperative |
@@ -95,6 +95,8 @@ struct ValueShape { | |||
} | |||
return true; | |||
} | |||
bool operator!=(const ValueShape& rhs) const { return !operator==(rhs); } | |||
}; | |||
static_assert(sizeof(size_t) >= sizeof(int)); | |||
@@ -47,6 +47,17 @@ class StringValue; | |||
class Operator; | |||
class ValueRefList; | |||
template <typename T> | |||
class Type { | |||
private: | |||
const size_t m_code = T::TYPE_CODE; | |||
public: | |||
inline size_t code() const { return m_code; } | |||
}; | |||
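Type<T> snapshots the registered TYPE_CODE once, so hot paths can hold the tag in a local instead of re-reading the static on every check; this excerpt only shows code(), while the scalar rules above also receive it as Type<ScalarValue> (the make() call in subgraph_op_rule presumably lives in the fuller definition). A short sketch with a hypothetical container of values:

```cpp
// Sketch: hoist the type lookup out of a loop (v_list is hypothetical).
Type<ScalarValue> scalar_type;
size_t nr_scalars = 0;
for (auto&& v : v_list) {
    if (v.is(scalar_type)) {  // is()/as()/cast() accept the tag directly
        ++nr_scalars;
    }
}
```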
/** | |||
* \brief an smart reference of value | |||
* | |||
@@ -64,8 +75,9 @@ public: | |||
protected: | |||
mutable storage_t m_storage; | |||
size_t m_id = std::numeric_limits<size_t>::max(); | |||
ValueRef(storage_t storage) { m_storage = storage; } | |||
inline ValueRef(storage_t storage); | |||
private: | |||
/** | |||
@@ -75,6 +87,10 @@ private: | |||
*/ | |||
storage_t& storage() const; | |||
const Value* as(size_t typecode) const; | |||
bool is(size_t typecode) const; | |||
public: | |||
ValueRef() = default; | |||
@@ -86,7 +102,7 @@ public: | |||
* \return false if empty or type of value is not TValue | |||
*/ | |||
template <typename TValue> | |||
bool is() const; | |||
inline bool is(Type<TValue> type = {}) const; | |||
/** | |||
* \brief try cast value as target type | |||
@@ -95,7 +111,7 @@ public: | |||
* \return TValue* raw pointer if success, otherwise nullptr | |||
*/ | |||
template <typename TValue> | |||
const TValue* as() const; | |||
inline const TValue* as(Type<TValue> type = {}) const; | |||
/** | |||
* \brief cast value to target type | |||
@@ -104,7 +120,7 @@ public: | |||
* \return TValue& reference of value | |||
*/ | |||
template <typename TValue> | |||
const TValue& cast() const; | |||
inline const TValue& cast(Type<TValue> type = {}) const; | |||
/** | |||
* \brief like as(), but returns TypedValueRef instead | |||
@@ -113,7 +129,13 @@ public: | |||
* \return TypedValueRef<TValue> reference if success, otherwise empty reference | |||
*/ | |||
template <typename TValue> | |||
inline TypedValueRef<TValue> as_ref() const; | |||
inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const; | |||
template <typename TValue> | |||
inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const; | |||
template <typename TValue> | |||
void on_cast_failure() const; | |||
operator bool() const { return bool(m_storage); } | |||
@@ -132,7 +154,7 @@ public: | |||
ValueRef unwrap() const; | |||
std::string to_string() const; | |||
std::string raw_type() const; | |||
uint64_t id() const; | |||
uint64_t id() const { return m_id; } | |||
size_t hash() const { return id(); } | |||
static ValueRef make(storage_t storage); | |||
@@ -144,7 +166,7 @@ public: | |||
friend class TypedValueRef; | |||
template <typename T> | |||
friend class ValueImpl; | |||
friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs); | |||
}; | |||
template <> | |||
@@ -244,7 +266,7 @@ public: | |||
using ref_t = TypedValueRef<T>; | |||
using weak_ref_t = TypedValueWeakRef<T>; | |||
static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); | |||
static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); | |||
/** | |||
* \brief helper function for construct a value | |||
@@ -254,7 +276,7 @@ public: | |||
* \return TypedValueRef<T> reference of value | |||
*/ | |||
template <typename... TArgs> | |||
static TypedValueRef<T> make(TArgs&&... args) { | |||
static MGB_NOINLINE TypedValueRef<T> make(TArgs&&... args) { | |||
static_assert(std::is_final_v<T>); | |||
return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...)); | |||
} | |||
@@ -279,46 +301,60 @@ public: | |||
bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; } | |||
}; | |||
inline ValueRef::ValueRef(storage_t storage) { | |||
// mgb_assert(storage); | |||
m_storage = storage; | |||
m_id = m_storage->m_id; | |||
} | |||
template <typename TValue> | |||
const TValue* ValueRef::as() const { | |||
inline const TValue* ValueRef::as(Type<TValue> type) const { | |||
static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>); | |||
auto storage = this->storage(); | |||
if (!storage) { | |||
return nullptr; | |||
} | |||
if (storage->m_typecode != TValue::TYPE_CODE) { | |||
return nullptr; | |||
} | |||
return static_cast<TValue*>(storage.get()); | |||
return static_cast<const TValue*>(as(type.code())); | |||
} | |||
template <typename TValue> | |||
const TValue& ValueRef::cast() const { | |||
auto* ptr = as<TValue>(); | |||
if (!ptr) { | |||
// if this is ErrorValue, rethrow directly | |||
storage()->try_rethrow(); | |||
mgb_assert( | |||
ptr, "expect type %s, got %s", typeid(TValue).name(), | |||
to_string().c_str()); | |||
inline const TValue& ValueRef::cast(Type<TValue> type) const { | |||
auto* ptr = as<TValue>(type); | |||
if (mgb_unlikely(!ptr)) { | |||
on_cast_failure<TValue>(); | |||
} | |||
return *ptr; | |||
return static_cast<const TValue&>(*ptr); | |||
} | |||
template <typename TValue> | |||
inline bool ValueRef::is(Type<TValue> type) const { | |||
return is(type.code()); | |||
} | |||
template <typename TValue> | |||
bool ValueRef::is() const { | |||
auto* ptr = as<TValue>(); | |||
return ptr != nullptr; | |||
inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const { | |||
if (!is<TValue>(type)) { | |||
return {}; | |||
} | |||
return TypedValueRef<TValue>(*this); | |||
} | |||
template <typename TValue> | |||
TypedValueRef<TValue> ValueRef::as_ref() const { | |||
if (!is<TValue>()) { | |||
inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const { | |||
if (!m_storage) { | |||
return {}; | |||
} | |||
if (mgb_unlikely(!is<TValue>(type))) { | |||
on_cast_failure<TValue>(); | |||
} | |||
return TypedValueRef<TValue>(*this); | |||
} | |||
template <typename TValue> | |||
void ValueRef::on_cast_failure() const { | |||
// if this is ErrorValue, rethrow directly | |||
storage()->try_rethrow(); | |||
mgb_assert( | |||
storage()->m_typecode != TValue::TYPE_CODE, "expect type %s, got %s", | |||
typeid(TValue).name(), to_string().c_str()); | |||
} | |||
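The definitions above all funnel through a Type<TValue> tag whose code() is compared against the storage's m_typecode: as()/as_ref() report a mismatch by returning an empty result, while cast()/cast_ref() hand it to on_cast_failure(). The sketch below reproduces that shape in a self-contained program; TypeTag, Ref and IntBox are made-up stand-ins, not MegEngine's classes:

```cpp
// Self-contained illustration of the tag-based as()/cast() split: as() yields
// nullptr on a type mismatch, cast() treats a mismatch as an error.
#include <cassert>
#include <cstddef>
#include <cstdio>

template <typename T>
struct TypeTag {
    std::size_t code() const { return T::TYPE_CODE; }
};

struct Base {
    std::size_t m_typecode;
    explicit Base(std::size_t code) : m_typecode(code) {}
    virtual ~Base() = default;
};

struct IntBox final : Base {
    static inline const std::size_t TYPE_CODE = 1;
    int payload;
    explicit IntBox(int v) : Base(TYPE_CODE), payload(v) {}
};

struct Ref {
    const Base* ptr = nullptr;

    // as(): returns nullptr on type mismatch, never throws
    template <typename T>
    const T* as(TypeTag<T> tag = {}) const {
        if (!ptr || ptr->m_typecode != tag.code())
            return nullptr;
        return static_cast<const T*>(ptr);
    }

    // cast(): asserts on mismatch, mirroring the cast()/on_cast_failure() pair above
    template <typename T>
    const T& cast(TypeTag<T> tag = {}) const {
        auto* p = as<T>(tag);
        assert(p && "unexpected value type");
        return *p;
    }
};

int main() {
    IntBox box{42};
    Ref ref{&box};
    if (auto* p = ref.as<IntBox>())  // tag defaulted at the call site
        std::printf("payload = %d\n", p->payload);
    std::printf("cast payload = %d\n", ref.cast<IntBox>().payload);
}
```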
/** | |||
* \brief ValueRef with concrete type, convenient for dereference | |||
* | |||
@@ -361,11 +397,87 @@ private: | |||
public: | |||
TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {} | |||
TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {} | |||
TypedValueRef<T> lock() { return ValueWeakRef::lock().template as_ref<T>(); } | |||
TypedValueRef<T> lock() { | |||
auto value = ValueWeakRef::lock(); | |||
if (value) { | |||
return value.template as_ref<T>(); | |||
} else { | |||
return {}; | |||
} | |||
} | |||
}; | |||
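The new lock() body checks that the weak reference is still alive before converting, so an expired reference produces an empty TypedValueRef instead of converting a dead handle. The standard-library analogue below shows the same check-then-downcast pattern; it is illustrative only and does not use MegEngine types:

```cpp
// Lock a weak reference first, then attempt the typed conversion: an expired
// reference comes back as an empty handle rather than a bogus downcast.
#include <cstdio>
#include <memory>

struct Base { virtual ~Base() = default; };
struct Derived final : Base { int payload = 7; };

std::shared_ptr<Derived> lock_as_derived(const std::weak_ptr<Base>& weak) {
    if (auto strong = weak.lock()) {                    // still alive?
        return std::dynamic_pointer_cast<Derived>(strong);
    }
    return {};                                          // expired: empty handle
}

int main() {
    std::weak_ptr<Base> weak;
    {
        auto strong = std::make_shared<Derived>();
        weak = strong;
        std::printf("alive: %d\n", lock_as_derived(weak) ? 1 : 0);    // prints 1
    }
    std::printf("expired: %d\n", lock_as_derived(weak) ? 1 : 0);      // prints 0
}
```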
// TODO: add proxy value type, which is meant to be reset in the end | |||
class ValueRefList { | |||
private: | |||
ValueRef* m_data = nullptr; | |||
size_t m_size = 0; | |||
std::aligned_storage_t<sizeof(ValueRef), alignof(ValueRef)> m_storage; | |||
private: | |||
void init(size_t nr_elems); | |||
ValueRef* inline_storage() { return reinterpret_cast<ValueRef*>(&m_storage); } | |||
public: | |||
ValueRefList() = default; | |||
ValueRefList(size_t nr_elems); | |||
ValueRefList(ValueRef item); | |||
ValueRefList(std::initializer_list<ValueRef> values); | |||
template <typename TIterator> | |||
ValueRefList(TIterator begin, TIterator end); | |||
ValueRefList(const ValueRefList& rhs); | |||
ValueRefList(ValueRefList&& rhs); | |||
ValueRefList& operator=(const ValueRefList& rhs); | |||
ValueRefList& operator=(ValueRefList&& rhs); | |||
~ValueRefList(); | |||
void clear(); | |||
ValueRef* begin() { return m_data; } | |||
ValueRef* end() { return m_data + m_size; } | |||
const ValueRef* cbegin() const { return m_data; } | |||
const ValueRef* cend() const { return m_data + m_size; } | |||
size_t size() const { return m_size; } | |||
ValueRef& at(size_t idx) { | |||
mgb_assert(idx < m_size); | |||
return m_data[idx]; | |||
} | |||
const ValueRef& at(size_t idx) const { | |||
mgb_assert(idx < m_size); | |||
return m_data[idx]; | |||
} | |||
ValueRef& operator[](size_t idx) { return m_data[idx]; } | |||
const ValueRef& operator[](size_t idx) const { return m_data[idx]; } | |||
ValueRef* data() { return m_data; } | |||
const ValueRef* data() const { return m_data; } | |||
bool empty() const { return m_size == 0; } | |||
ValueRef& front() { | |||
mgb_assert(m_size > 0); | |||
return m_data[0]; | |||
} | |||
ValueRef& back() { | |||
mgb_assert(m_size > 0); | |||
return m_data[m_size - 1]; | |||
} | |||
}; | |||
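ValueRefList keeps room for exactly one ValueRef inline (m_storage) and only needs separate storage for multi-output cases; init() and the special members are declared here but defined elsewhere, so the following is just a miniature of the same small-buffer technique, with std::string standing in for ValueRef rather than MegEngine's actual implementation:

```cpp
// Miniature of the "inline storage for exactly one element" idea: the common
// single-output case is constructed in the inline buffer with no heap
// allocation, anything larger falls back to a heap array.
#include <cstddef>
#include <cstdio>
#include <memory>
#include <new>
#include <string>

class SmallStringList {
    std::string* m_data = nullptr;
    std::size_t m_size = 0;
    alignas(std::string) unsigned char m_storage[sizeof(std::string)];

    std::string* inline_storage() { return reinterpret_cast<std::string*>(m_storage); }

public:
    SmallStringList() = default;

    explicit SmallStringList(std::size_t n) : m_size(n) {
        if (n == 0)
            return;
        // one element: construct in place in the inline buffer; otherwise allocate
        m_data = (n == 1) ? new (inline_storage()) std::string()
                          : new std::string[n];
    }

    SmallStringList(const SmallStringList&) = delete;
    SmallStringList& operator=(const SmallStringList&) = delete;

    ~SmallStringList() {
        if (!m_data)
            return;
        if (m_data == inline_storage())
            std::destroy_at(m_data);  // destroy the in-place element
        else
            delete[] m_data;          // heap-backed case
    }

    std::string& operator[](std::size_t i) { return m_data[i]; }
    std::size_t size() const { return m_size; }
};

int main() {
    SmallStringList one(1);   // stays entirely in the inline buffer
    one[0] = "single output";
    SmallStringList many(3);  // falls back to a heap array
    many[2] = "third output";
    std::printf("%s / %s\n", one[0].c_str(), many[2].c_str());
}
```

The commented-out SmallVector<ValueRef, 1> alias further down would give the same one-element inline capacity; the hand-written class presumably trades that generality for a smaller interface tailored to apply()'s needs.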
template <typename TIterator> | |||
ValueRefList::ValueRefList(TIterator begin, TIterator end) : ValueRefList(end - begin) { | |||
for (size_t i = 0; i < m_size; ++i) { | |||
m_data[i] = *(begin + i); | |||
} | |||
} | |||
inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_size(1) { | |||
new (m_data) ValueRef(); | |||
m_data[0] = std::move(item); | |||
} | |||
/*class ValueRefList : public SmallVector<ValueRef, 1> { | |||
public: | |||
using SmallVector::SmallVector; | |||
};*/ | |||
} // namespace imperative | |||
} // namespace mgb | |||