GitOrigin-RevId: 860028e1af
tags/v1.9.0
@@ -13,6 +13,7 @@
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "./tensor.h"
@@ -21,6 +21,7 @@
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"
@@ -52,8 +53,48 @@ namespace mgb::imperative::python {
namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;

struct SymbolVarContext {
    TransformationContext context;
    cg::ComputingGraph* graph;

    SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) {
        Transformation::swap_context(context);
    }

    void init() {
        std::make_shared<SymbolTransformation>(graph)->register_at(
                Transformation::top());
        std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
    }

    ~SymbolVarContext() { Transformation::swap_context(context); }
};

ValueRef symvar2val(py::handle py_symbol_var) {
    auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
    ValueRef value = SymbolValue::make(symbol_var->m_node);
    if (symbol_var->is_scalar) {
        value = ScalarValue::make(value);
    }
    return value;
}

py::object val2symvar(py::handle typeobj, ValueRef value) {
    bool is_scalar = false;
    if (auto* scalar_value = value.as<ScalarValue>()) {
        value = scalar_value->value();
        is_scalar = true;
    }
    auto* node = value.cast<SymbolValue>().node();
    auto py_symbol_var =
            typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
    py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
    return py_symbol_var;
}
} // namespace

interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyObject *cpp_use_symbolic_shape, *cpp_astensor1d;
@@ -91,36 +132,17 @@ PyObject* py_apply(
    if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
        // swap to a special context to reuse scalar handle
        TransformationContext symbol_var_context;
        Transformation::swap_context(symbol_var_context);
        CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }};
        auto* graph =
                py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph();
        std::make_shared<SymbolTransformation>(graph)->register_at(
                Transformation::top());
        std::make_shared<ScalarTransformation>()->register_at(
                Transformation::top());
        SymbolVarContext context(
                py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
        context.init();
        for (size_t i = 0; i < nargs; ++i) {
            auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
            ValueRef input = SymbolValue::make(py_input->m_node);
            if (py_input->is_scalar) {
                input = ScalarValue::make(input);
            }
            tensors[i] = input;
            tensors[i] = symvar2val(args[i]);
        }
        auto outputs = imperative::apply(*op, tensors);
        auto ret = pybind11::tuple(outputs.size());
        auto typeobj = py::handle(args[0]).get_type();
        for (size_t i = 0; i < outputs.size(); ++i) {
            bool is_scalar = false;
            if (auto* scalar_value = outputs[i].as<ScalarValue>()) {
                outputs[i] = scalar_value->value();
                is_scalar = true;
            }
            auto* node = outputs[i].cast<SymbolValue>().node();
            ret[i] = typeobj(
                    pybind11::cast(node, pybind11::return_value_policy::automatic));
            py::handle(ret[i]).cast<PySymbolVar*>()->is_scalar = is_scalar;
            ret[i] = val2symvar(typeobj, outputs[i]);
        }
        return ret.release().ptr();
    }
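The helpers above are what the py_apply hunk consumes. For readers skimming the diff, the intended lifecycle reads as follows in isolation; this is a sketch only, not part of the commit, and run_on_symbolvars together with its parameters are hypothetical names.

// Sketch only: SymbolVarContext must outlive every imperative::apply call that
// should see the transformations registered by init().
py::tuple run_on_symbolvars(
        std::shared_ptr<OpDef> op, const std::vector<py::object>& vars,
        py::handle typeobj) {
    // the constructor swaps the current transformation stack out of the way
    SymbolVarContext context(vars[0].cast<PySymbolVar*>()->m_node->owner_graph());
    // init() registers SymbolTransformation and ScalarTransformation on the
    // now-empty stack
    context.init();
    SmallVector<ValueRef> values(vars.size());
    for (size_t i = 0; i < vars.size(); ++i) {
        values[i] = symvar2val(vars[i]);  // PySymbolVar -> SymbolValue (+ ScalarValue)
    }
    auto outputs = imperative::apply(*op, values);
    py::tuple ret(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        ret[i] = val2symvar(typeobj, outputs[i]);  // back to a PySymbolVar object
    }
    return ret;  // ~SymbolVarContext restores the previous transformation stack
}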
@@ -1537,17 +1559,29 @@ void init_tensor(py::module m) {
        }
    });
    m.def("reduce_to_scalar", [](py::object op, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        auto make_scalar_shape = [&](CompNode device) {
            return imperative::apply(
                    CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
                    HostStorage::make(device))[0];
    m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
        auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
            auto make_scalar_shape = [&](CompNode device) {
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
                        HostStorage::make(device))[0];
            };
            return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
        };
        auto output = imperative::apply(
                *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(),
                make_scalar_shape(tw->m_tensor->comp_node()))[0];
        return TensorWrapper::make(py_tensor_type, output);
        if (py::isinstance<PySymbolVar>(tensor)) {
            auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
            SymbolVarContext context(graph);
            context.init();
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), symvar2val(tensor));
            auto typeobj = tensor.get_type();
            return val2symvar(typeobj, output);
        } else {
            auto* tw = TensorWrapper::try_cast(tensor.ptr());
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
            return TensorWrapper::make(py_tensor_type, output);
        }
    });
m.def("name_tensor", [](std::string name, py::object tensor) { | m.def("name_tensor", [](std::string name, py::object tensor) { | ||||
@@ -1557,7 +1591,7 @@ void init_tensor(py::module m) { | |||||
}); | }); | ||||
m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool { | m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool { | ||||
ValueRefList values(tensors.size()); | |||||
SmallVector<ValueRef> values(tensors.size()); | |||||
for (size_t i = 0; i < tensors.size(); ++i) { | for (size_t i = 0; i < tensors.size(); ++i) { | ||||
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | ||||
} | } | ||||
@@ -1570,17 +1604,16 @@ void init_tensor(py::module m) { | |||||
}); | }); | ||||
m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object { | m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object { | ||||
ValueRefList values(tensors.size()); | |||||
SmallVector<ValueRef> values(tensors.size()); | |||||
for (size_t i = 0; i < tensors.size(); ++i) { | for (size_t i = 0; i < tensors.size(); ++i) { | ||||
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | ||||
} | } | ||||
auto outputs = imperative::apply(GetGradKey(), values); | |||||
if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) { | |||||
return py::reinterpret_borrow<py::object>( | |||||
GradKeyWrapper::wrap_t::pycast(GradKeyWrapper::get(*grad_key_val))); | |||||
} else { | |||||
auto output = imperative::apply(GetGradKey(), values)[0]; | |||||
if (!output) { | |||||
return py::none(); | return py::none(); | ||||
} | } | ||||
return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast( | |||||
GradKeyWrapper::get(output.cast<GradKeyValue>()))); | |||||
}); | }); | ||||
m.def("set_grad", [](py::object py_key, py::function backward_fn, | m.def("set_grad", [](py::object py_key, py::function backward_fn, | ||||
@@ -1612,7 +1645,7 @@ void init_tensor(py::module m) { | |||||
} | } | ||||
return input_grads; | return input_grads; | ||||
}; | }; | ||||
ValueRefList values(inputs.size() + outputs.size()); | |||||
SmallVector<ValueRef> values(inputs.size() + outputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | for (size_t i = 0; i < inputs.size(); ++i) { | ||||
values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data(); | values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data(); | ||||
} | } | ||||
@@ -1669,6 +1702,10 @@ void init_tensor(py::module m) { | |||||
return reprs; | return reprs; | ||||
}); | }); | ||||
m.def("print_stats", [] { imperative::Stats::print(); }); | |||||
m.def("reset_stats", [] { imperative::Stats::reset(); }); | |||||
py::register_exception<TraceError>(m, "TraceError"); | py::register_exception<TraceError>(m, "TraceError"); | ||||
} | } | ||||
@@ -67,7 +67,8 @@ struct TransformationManager {
    }
};

class PyValue final : public MixinValueImpl<PyValue, pybind11::object> {
class PyValue final
        : public MixinValueImpl<PyValue, ValueKind::Object, pybind11::object> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -14,13 +14,9 @@
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {

void imperative_log_profile_begin(const char* message);
void imperative_log_profile(const char* message);
void imperative_log_profile_end(const char* message);

namespace imperative {
namespace {
@@ -19,6 +19,7 @@
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/imperative/utils/to_string.h"
#include "../blob_manager_impl.h"
@@ -1,4 +1,5 @@
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {
@@ -11,6 +11,7 @@
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {
@@ -40,9 +41,6 @@ ShapeValue::ref_t InterpreterInfo::shape() const {
ValueRefList InterpreterTransformation::apply_op(
        const ApplyOp& apply_op, Span<ValueRef> inputs) {
    if (apply_op.op().same_type<FastpathCopy>()) {
        return {inputs[0]};
    }
    SmallVector<Handle> input_handles;
    SmallVector<Handle> output_handles;
    CleanupGuard _{[&] {
@@ -111,7 +109,11 @@ ValueRefList InterpreterTransformation::apply_create_tensor(
ValueRefList InterpreterTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_val = op.as<ApplyOp>()) {
        return apply_op(*op_val, inputs);
        if (op_val->op().same_type<FastpathCopy>()) {
            return inputs[0];
        } else {
            return apply_op(*op_val, inputs);
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
@@ -11,8 +11,11 @@
#include "megbrain/imperative/transformations/grad.h"
#include <variant>
#include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/imperative/utils/stats.h"
#include <range/v3/all.hpp>
@@ -20,20 +23,21 @@ namespace mgb {
namespace imperative {
static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs,
        const OpDef& op, Span<ValueRef> inputs, Span<ValueRef> outputs,
        Span<bool> inputs_require_grad) {
    // hash
    using OptimizedBackwardGraphCache = OpMethResultCache<
            std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
    thread_local auto& cache =
            *ResourceManager::create_local<OptimizedBackwardGraphCache>();
    OptimizedBackwardGraphCache::key_t cache_key{op};
    OptimizedBackwardGraphCache::key_t cache_key{op.shared_from_this()};
    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
    std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
    cache_key.extra<0>() = inputs_require_grad.copy_into<SmallVector<bool>>();
    input_descs.resize(inputs.size());
    // some overhead, consider simplifying LogicalTensorDesc
    for (size_t i = 0; i < inputs.size(); ++i) {
        input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>();
        input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>();
        input_descs[i].layout.dtype = *inputs[i].dtype();
        input_descs[i].comp_node = *inputs[i].device();
    }
    auto iter = cache.find(cache_key);
@@ -45,7 +49,7 @@ static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_gra
    SmallVector<bool> output_has_grad(outputs.size(), true);
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            *op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
            op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
@@ -235,7 +239,7 @@ GradValue::ref_t GradKey::attach(
    } else {
        GradSlotPtr grad_slot;
        auto& grad_fn = grad_slot.m_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn = LocalPtr<GradFn>::make();
        grad_fn->m_key = shared_from_this();
        grad_fn->m_slots.resize(1);
        grad_slot.m_index = 0;
@@ -260,17 +264,21 @@ ValueRefList GradTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    auto fallback = [&] {
        ValueRefList unwrapped_inputs(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (auto grad_value = as_grad_value(inputs[i])) {
                unwrapped_inputs[i] = grad_value->m_value;
            } else {
                unwrapped_inputs[i] = inputs[i];
        {
            // overhead
            for (size_t i = 0; i < inputs.size(); ++i) {
                if (auto&& grad_value = as_grad_value(inputs[i])) {
                    unwrapped_inputs[i] = grad_value->m_value;
                } else {
                    unwrapped_inputs[i] = inputs[i];
                }
            }
        }
        return imperative::apply(op, unwrapped_inputs);
    };
    if (auto* get_attr = op.as<GetAttr>()) {
        if (auto grad_value = as_grad_value(inputs.item())) {
    if (op.is<GetAttr>()) {
        // overhead
        if (auto&& grad_value = as_grad_value(inputs.item())) {
            return imperative::apply(op, grad_value->m_value);
        } else {
            return imperative::apply(op, inputs);
@@ -281,28 +289,29 @@ ValueRefList GradTransformation::apply_transformation(
    }
    if (auto* op_val = op.as<ApplyOp>()) {
        size_t nr_require_grad = 0;
        SmallVector<bool> require_grads;
        for (auto&& input : inputs) {
            if (is_grad_value(input)) {
        SmallVector<bool> require_grads(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (is_grad_value(inputs[i])) {
                nr_require_grad++;
                require_grads.push_back(true);
                require_grads[i] = true;
            } else {
                require_grads.push_back(false);
                require_grads[i] = false;
            }
        }
        if (nr_require_grad == 0) {
            return imperative::apply(op, inputs);
        }
        ValueRefList captured_inputs(inputs.size());
        SmallVector<ValueRef> captured_inputs(inputs.size());
        SmallVector<bool> inputs_require_grad(inputs.size());
        // capture value so that trace could assume input as same
        auto capture_value = [](ValueRef value) {
        auto capture_value = [](const ValueRef& value) {
            // TODO: fastpath copy shouldn't be an OpDef
            return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
            static auto fastpath_copy = FastpathCopy::make();
            return imperative::apply(ApplyOp(*fastpath_copy), value)[0];
        };
        for (size_t i = 0; i < inputs.size(); ++i) {
            auto& input = inputs[i];
            if (auto grad_value = as_grad_value(input)) {
            if (auto&& grad_value = as_grad_value(input)) {
                captured_inputs[i] = capture_value(grad_value->m_value);
                inputs_require_grad[i] = true;
            } else {
@@ -310,32 +319,28 @@ ValueRefList GradTransformation::apply_transformation(
                inputs_require_grad[i] = false;
            }
        }
        decltype(std::declval<GradFn>().m_backward) backward_storage;
        // copying grad_fn->m_backward is expensive
        auto grad_fn = LocalPtr<GradFn>::make();
        auto& backward_storage = grad_fn->m_backward;
        auto outputs = [&] {
            auto backward_rule =
                    CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
            if (backward_rule) {
                CustomBackward backward;
                auto optional_outputs = backward_rule(
                        op_val->op(), {captured_inputs.data(), captured_inputs.size()},
                        {inputs_require_grad.data(), inputs_require_grad.size()},
                        backward);
                        op_val->op(), captured_inputs, inputs_require_grad, backward);
                if (optional_outputs) {
                    backward_storage = backward;
                    // backward by rule
                    return *optional_outputs;
                }
            }
            auto outputs = imperative::apply(
                    op, {captured_inputs.begin(), captured_inputs.end()});
            auto outputs = imperative::apply(op, captured_inputs);
            auto backward_graph = make_optimized_backward_graph(
                    op.cast<ApplyOp>().op().shared_from_this(),
                    {captured_inputs.begin(), captured_inputs.end()},
                    {outputs.data(), outputs.size()},
                    {inputs_require_grad.data(), inputs_require_grad.size()});
                    op_val->op(), captured_inputs, outputs, inputs_require_grad);
            if (backward_graph) {
                backward_storage = BackwardGraphWithClosure(
                        backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
                        backward_graph, op_val->op().shared_from_this(),
                        {captured_inputs.begin(), captured_inputs.end()},
                        {outputs.data(), outputs.size()});
                // backward by make_backward_graph
@@ -348,18 +353,17 @@ ValueRefList GradTransformation::apply_transformation(
        if (std::holds_alternative<std::monostate>(backward_storage)) {
            return outputs;
        }
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(outputs.size());
        grad_fn->m_backward = backward_storage;
        mgb_assert(!outputs.empty());
        grad_fn->m_dests.reserve(inputs.size());
        // clang-format off
        std::visit([&](auto& backward) {
        auto visitor = [&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                // little overhead
                for (size_t i = 0; i < inputs.size(); ++i) {
                    if (backward.input_has_grad(i) && require_grads[i]) {
                        auto& input_grad_slot =
@@ -373,19 +377,23 @@ ValueRefList GradTransformation::apply_transformation(
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        // little overhead: Value::make
                        auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
                        outputs[i] = record_grad(grad_value);
                    }
                }
            }
        }, grad_fn->m_backward);
        };
        // std::visit may be slightly slower than a direct if
        std::visit(visitor, backward_storage);
        // clang-format on
        mgb_assert(!grad_fn->m_slots.empty());
        m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
        return outputs;
    } else if (op.is<CreateTensor>()) {
        return imperative::apply(op, inputs);
    } else if (auto* attach_grad = op.as<AttachGrad>()) {
    }
    if (auto* attach_grad = op.as<AttachGrad>()) {
        if (!has_key(attach_grad->key())) {
            return fallback();
        }
@@ -408,7 +416,7 @@ ValueRefList GradTransformation::apply_transformation(
        return {};
    } else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
        if (has_key(is_attached_to->key())) {
            if (auto grad_value = as_grad_value(inputs[0])) {
            if (auto&& grad_value = as_grad_value(inputs[0])) {
                // TODO: assert grad_fn
                return {BoolValue::make(true)};
            }
@@ -416,7 +424,7 @@ ValueRefList GradTransformation::apply_transformation(
        return {BoolValue::make(false)};
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // TODO: merge SetGrad and ApplyOp
        auto grad_fn = std::make_shared<GradFn>();
        auto grad_fn = LocalPtr<GradFn>::make();
        auto& backward =
                std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
        size_t nr_inputs = set_grad->nr_inputs();
@@ -433,7 +441,7 @@ ValueRefList GradTransformation::apply_transformation(
        grad_fn->m_slots.resize(nr_outputs);
        grad_fn->m_dests.reserve(nr_inputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (auto grad_value = as_grad_value(inputs_[i])) {
            if (auto&& grad_value = as_grad_value(inputs_[i])) {
                auto& input_grad_slot = grad_value->m_slot;
                grad_fn->m_dests.emplace_back(grad_value->m_slot);
                grad_fn->m_dests.back().m_producer_record.insert_after(
@@ -461,21 +469,21 @@ ValueRefList GradTransformation::apply_transformation(
        }
        return {FunctionValue::make(make_backward_closure(inputs))};
    } else if (op.is<DetachGrad>()) {
        if (auto grad_value = as_grad_value(inputs[0])) {
        if (auto&& grad_value = as_grad_value(inputs[0])) {
            return {grad_value->m_value};
        } else {
            return {inputs[0]};
        }
    } else if (op.is<GetGradKey>()) {
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
            if (auto&& grad_value = as_grad_value(input)) {
                return {GradKeyValue::make(grad_value->m_key)};
            }
        }
        return imperative::apply(op, inputs);
    } else if (op.kind() == Operator::IdentityLike) {
        mgb_assert(inputs.size() == 1);
        if (auto grad_value = as_grad_value(inputs[0])) {
        if (auto&& grad_value = as_grad_value(inputs[0])) {
            auto output = imperative::apply(op, grad_value->m_value)[0];
            auto grad_output = GradValue::make(
                    output, grad_value->key(), grad_value->slot_for(m_key));
@@ -493,7 +501,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
    auto grad_key = m_key;
    std::vector<GradSlotPtr> y_slots;
    for (auto&& y : ys) {
        if (auto grad_value = as_grad_value(y)) {
        if (auto&& grad_value = as_grad_value(y)) {
            y_slots.push_back(grad_value->slot_for(grad_key));
        } else {
            y_slots.emplace_back();
@@ -13,6 +13,7 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {
@@ -185,7 +186,7 @@ ValueRefList subtensor_rule(
    bool is_scalar;
    mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
    if (auto shape = input.shape()) {
        size_t ndim = input.shape()->ndim;
        size_t ndim = shape->ndim;
        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
            if (idx) {
                ndim--;
@@ -193,6 +194,7 @@ ValueRefList subtensor_rule(
        }
        is_scalar = ndim == 0;
    } else {
        // assume not scalar
        is_scalar = false;
    }
    auto outputs = imperative::apply(subtensor, inputs);
@@ -341,12 +343,16 @@ ValueRefList ScalarTransformation::apply_transformation(
    if (auto* get_attr = op.as<GetAttr>()) {
        // fastpath for GetAttr
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* apply_op = op.as<ApplyOp>()) {
        if (apply_op->op().same_type<FastpathCopy>()) {
            return inputs[0];
        }
    }
    size_t nr_inputs = inputs.size();
    ValueRefList unwrapped_inputs(nr_inputs);
    bool inputs_mask[nr_inputs];
    SmallVector<bool> inputs_mask(nr_inputs);
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) {
        if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) {
            unwrapped_inputs[i] = scalar_value->value();
            inputs_mask[i] = true;
        } else {
@@ -358,8 +364,7 @@ ValueRefList ScalarTransformation::apply_transformation(
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(
                    apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs});
            return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask);
        } else {
            // TODO: repeat op
            return fallback();
@@ -215,8 +215,8 @@ ValueRefList::ValueRefList(size_t nr_elems) {
    init(nr_elems);
}
ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
        : ValueRefList(values.begin(), values.end()) {}
/*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
        : ValueRefList(values.begin(), values.end()) {}*/
ValueRefList::ValueRefList(const ValueRefList& rhs)
        : ValueRefList(rhs.cbegin(), rhs.cend()) {}
@@ -25,14 +25,16 @@ class GradKey;
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;

class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
class ShapeValue final
        : public MixinValueImpl<ShapeValue, ValueKind::Primitive, ValueShape> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class CompNodeValue final : public MixinValueImpl<CompNodeValue, CompNode> {
class CompNodeValue final
        : public MixinValueImpl<CompNodeValue, ValueKind::Primitive, CompNode> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -40,7 +42,7 @@ public:
};

// TODO: override factory method
class BoolValue final : public ValueImpl<BoolValue> {
class BoolValue final : public ValueImpl<BoolValue, ValueKind::Primitive> {
private:
    std::optional<bool> m_value;
@@ -53,14 +55,17 @@ public:
    void clear() override { m_value.reset(); }
};

class HostStorage final : public MixinValueImpl<HostStorage, HostTensorStorage> {
class HostStorage final
        : public MixinValueImpl<HostStorage, ValueKind::Primitive, HostTensorStorage> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class DeviceStorage final : public MixinValueImpl<DeviceStorage, DeviceTensorStorage> {
class DeviceStorage final
        : public MixinValueImpl<
                  DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -71,7 +76,7 @@ public:
 * \brief like HostTensorND mixin, but allow scalar value
 *
 */
class HostValue final : public ValueImpl<HostValue> {
class HostValue final : public ValueImpl<HostValue, ValueKind::Primitive> {
private:
    DType m_dtype;
    ValueShape m_shape;
@@ -94,9 +99,9 @@ public:
    }

    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    const ValueShape& shape() const { return m_shape; }
    CompNode device() const { return m_storage.comp_node(); }
    HostTensorStorage storage() const { return m_storage; }
    const HostTensorStorage& storage() const { return m_storage; }
    DTypeScalar item() const {
        mgb_assert(m_shape.is_scalar());
        return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());
@@ -109,7 +114,7 @@ public:
 * \brief like DeviceTensorND mixin, but allow scalar value
 *
 */
class DeviceValue final : public ValueImpl<DeviceValue> {
class DeviceValue final : public ValueImpl<DeviceValue, ValueKind::Primitive> {
private:
    DType m_dtype;
    ValueShape m_shape;
@@ -117,8 +122,8 @@ private:
public:
    DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage)
            : m_dtype(dtype), m_shape(shape), m_storage(storage) {}
    DeviceValue(DeviceTensorND value)
            : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
    DeviceValue(const DeviceTensorND& value)
            : DeviceValue(
                      value.dtype(), ValueShape::from(value.shape()), value.storage()) {
    }
@@ -132,28 +137,31 @@ public:
    }

    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    const ValueShape& shape() const { return m_shape; }
    CompNode device() const { return m_storage.comp_node(); }
    DeviceTensorStorage storage() const { return m_storage; }
    const DeviceTensorStorage& storage() const { return m_storage; }
    DeviceTensorND as_nd(bool allow_scalar = false) const;
};

class FunctionValue final : public MixinValueImpl<FunctionValue, GenericFunction> {
class FunctionValue final
        : public MixinValueImpl<FunctionValue, ValueKind::Primitive, GenericFunction> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class DTypeValue final : public MixinValueImpl<DTypeValue, DType> {
class DTypeValue final
        : public MixinValueImpl<DTypeValue, ValueKind::Primitive, DType> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class StringValue final : public MixinValueImpl<StringValue, std::string> {
class StringValue final
        : public MixinValueImpl<StringValue, ValueKind::Primitive, std::string> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -171,7 +179,8 @@ public:
    std::string message() const { return m_message; }
};

class ErrorValue final : public MixinValueImpl<ErrorValue, Error> {
class ErrorValue final
        : public MixinValueImpl<ErrorValue, ValueKind::Primitive, Error> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -47,9 +47,14 @@ constexpr bool is_all_value_ref_v =
        (... && (std::is_base_of_v<ValueRef, std::decay_t<TArgs>> ||
                 std::is_same_v<ValueRef, std::decay_t<TArgs>>));

template <typename T>
static ValueRefList apply(T&& op, const ValueRef& arg) {
    return imperative::apply(std::forward<T&&>(op), Span<ValueRef>{&arg, 1});
}

template <typename T, typename... TArgs>
static auto apply(T&& op, TArgs&&... args)
        -> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> {
static auto apply(T&& op, TArgs&&... args) -> std::enable_if_t<
        is_all_value_ref_v<TArgs...> && sizeof...(args) != 1, ValueRefList> {
    ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...};
    return imperative::apply(
            std::forward<T&&>(op),
@@ -54,6 +54,11 @@ struct OpMethArgs {
        return extras == rhs.extras;
    }

    template <size_t i>
    auto& extra() {
        return std::get<i>(extras);
    }

    struct hash_t {
        size_t operator()(const OpMethArgs& key) const { return key.hash(); }
    };
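The extra<i>() member added above is shorthand for std::get on the extras tuple of OpMethArgs; the backward-graph cache in grad.cpp earlier in this diff is the motivating call site. Both spellings below are equivalent (the second is the new one):

// Equivalent accesses to the first extra field of the cache key (illustrative only):
std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
cache_key.extra<0>() = inputs_require_grad.copy_into<SmallVector<bool>>();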
@@ -60,7 +60,7 @@ public:
};

class InterpreterValue final
        : public MixinValueImpl<InterpreterValue, InterpreterInfo> {
        : public MixinValueImpl<InterpreterValue, ValueKind::Object, InterpreterInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -104,37 +104,15 @@ struct ToStringTrait<GradSlot> {
    std::string operator()(const GradSlot& value) const { return value.to_string(); }
};

class GradFn {
private:
    std::weak_ptr<GradKey> m_key;
    std::vector<GradSlot> m_slots;
    std::vector<GradSlotProducerPtr> m_dests;
    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;

public:
    void clear() {
        m_key.reset();
        m_slots.clear();
        m_dests.clear();
        m_backward.emplace<std::monostate>();
    }

    std::string to_string() const;

    friend class GradSlotPtr;
    friend class GradKey;
    friend class GradTransformation;
};

class GradSlotPtr {
private:
    std::shared_ptr<GradFn> m_fn;
    LocalPtr<GradFn> m_fn;
    size_t m_index = 0;

public:
    GradSlotPtr(std::shared_ptr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
    GradSlotPtr(LocalPtr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
    GradSlotPtr() = default;
    GradSlot* operator->() const { return &m_fn->m_slots[m_index]; }
    GradSlot* operator->() const;

    operator bool() const { return bool(m_fn); }
@@ -171,7 +149,33 @@ struct ToStringTrait<GradSlotProducerPtr> {
    }
};

class GradValue final : public ValueImpl<GradValue> {
class GradFn {
private:
    std::weak_ptr<GradKey> m_key;
    SmallVector<GradSlot> m_slots;
    SmallVector<GradSlotProducerPtr> m_dests;
    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;

public:
    void clear() {
        m_key.reset();
        m_slots.clear();
        m_dests.clear();
        m_backward.emplace<std::monostate>();
    }

    std::string to_string() const;

    friend class GradSlotPtr;
    friend class GradKey;
    friend class GradTransformation;
};

inline GradSlot* GradSlotPtr::operator->() const {
    return &m_fn->m_slots[m_index];
}
class GradValue final : public ValueImpl<GradValue, ValueKind::Object> {
private:
    ValueRef m_value;
    std::shared_ptr<GradKey> m_key;
@@ -179,7 +183,7 @@ private:
public:
    GradValue(ValueRef value, std::shared_ptr<GradKey> key, GradSlotPtr slot = {})
            : m_value(value), m_key(key), m_slot(slot) {}
            : m_value(std::move(value)), m_key(std::move(key)), m_slot(slot) {}

    std::string to_string() const override;
@@ -209,12 +213,13 @@ public:
class GradKey : public std::enable_shared_from_this<GradKey> {
private:
    std::string m_name;
    std::vector<std::pair<std::weak_ptr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
    std::vector<std::pair<std::shared_ptr<GradFn>, std::shared_ptr<OpDef>>>
            m_frozen_tape;
    std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
    std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
    bool m_frozen = false;

public:
    GradKey() { m_tape.reserve(4 * 1024); }
    void backward();
    GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
    const std::string& name() const { return m_name; }
@@ -225,7 +230,8 @@ public:
};

class GradKeyValue final
        : public MixinValueImpl<GradKeyValue, std::shared_ptr<GradKey>> {
        : public MixinValueImpl<
                  GradKeyValue, ValueKind::Primitive, std::shared_ptr<GradKey>> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -248,7 +254,7 @@ public:
        return tensor;
    }

    bool is_grad_value(ValueRef value) {
    bool is_grad_value(const ValueRef& value) {
        if (auto* grad_value = value.as<GradValue>()) {
            if (grad_value->has_key(m_key)) {
                return true;
@@ -266,13 +272,14 @@ public:
     * \param value
     * \return GradValue::ref_t
     */
    GradValue::ref_t as_grad_value(ValueRef value) {
        if (auto grad_value = value.as_ref<GradValue>()) {
    const GradValue::ref_t& as_grad_value(const ValueRef& value) {
        auto&& grad_value = value.as_ref<GradValue>();
        if (grad_value) {
            if (grad_value->has_key(m_key)) {
                return grad_value;
            }
        }
        return {};
        return GradValue::ref_t::nil;
    }

    bool has_key(std::shared_ptr<GradKey> key) {
@@ -39,7 +39,8 @@ public:
    std::string name() const { return m_name; }
};

class LazyEvalValue final : public MixinValueImpl<LazyEvalValue, LazyEvalInfo> {
class LazyEvalValue final
        : public MixinValueImpl<LazyEvalValue, ValueKind::Object, LazyEvalInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -17,7 +17,7 @@
namespace mgb::imperative {

class ScalarValue final : public ValueImpl<ScalarValue> {
class ScalarValue final : public ValueImpl<ScalarValue, ValueKind::Object> {
private:
    ValueRef m_value;
@@ -22,7 +22,7 @@
namespace mgb::imperative {

class SymbolValue final : public ValueImpl<SymbolValue> {
class SymbolValue final : public ValueImpl<SymbolValue, ValueKind::Object> {
private:
    VarNode* m_node = nullptr;
@@ -111,7 +111,8 @@ public:
    size_t id() const { return m_id; }
};

class TracingValue final : public MixinValueImpl<TracingValue, TracingInfo> {
class TracingValue final
        : public MixinValueImpl<TracingValue, ValueKind::Object, TracingInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -256,7 +257,8 @@ public:
    }
};

class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
class TracedValue final
        : public MixinValueImpl<TracedValue, ValueKind::Object, TracedInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -0,0 +1,140 @@
#pragma once

#include <chrono>
#include <iostream>
#include <string>
#include <vector>

namespace mgb {
namespace imperative {
namespace stats {

#define MGE_ENABLE_STATS 0

class Timer {
public:
    using clock_t = std::chrono::system_clock;

private:
    clock_t::duration m_duration = clock_t::duration{0};
    size_t m_timing = 0;
    const char* m_name = nullptr;
    uint64_t m_count = 0;
    size_t m_enabled = 1;
    bool m_default_enabled = true;

    struct TimeScopeRecursive {
        Timer& timer;
        clock_t::time_point start;
        bool released = false;

        TimeScopeRecursive(Timer& timer) : timer(timer) {
            if (timer.m_enabled && !timer.m_timing++) {
                start = clock_t::now();
            }
        }

        ~TimeScopeRecursive() { release(); }

        void release() {
            if (released) {
                return;
            }
            if (timer.m_enabled) {
                if (!--timer.m_timing) {
                    timer.m_duration += (clock_t::now() - start);
                }
                timer.m_count++;
            }
            released = true;
        }
    };

    struct EnableScope {
        Timer& timer;
        bool released = false;

        EnableScope(Timer& timer) : timer(timer) { timer.m_enabled++; }

        ~EnableScope() { release(); }

        void release() {
            if (released) {
                return;
            }
            timer.m_enabled--;
            released = true;
        }
    };

    using TimeScope = TimeScopeRecursive;

public:
    Timer(const char* name, bool default_enabled);

    const char* name() { return m_name; }
    auto time_scope() { return TimeScope(*this); }
    auto time_scope_recursive() { return TimeScopeRecursive(*this); }
    auto enable_scope() { return EnableScope(*this); }
    void reset() {
        m_duration = clock_t::duration{0};
        m_count = 0;
        m_enabled = m_default_enabled ? 1 : 0;
    }

    clock_t::duration get() const { return m_duration; }
    uint64_t count() const { return m_count; }
};

} // namespace stats

struct Stats {
    static inline std::vector<stats::Timer*> sm_timers;

    // register your timers here
    // for example:
    //
    //     static inline stats::Timer mytimer;
    //
    // then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code

    static void print() {
        std::vector<const char*> unused_timers;
        for (auto* timer : sm_timers) {
            if (timer->count() == 0) {
                unused_timers.push_back(timer->name());
            } else {
printf("%s costs %ld ns, happens %ld times\n", timer->name(), | |||||
timer->get().count(), timer->count()); | |||||
            }
        }
        if (!unused_timers.empty()) {
            printf("%zu timers unused\n", unused_timers.size());
        }
    }

    static void reset() {
        for (auto* timer : sm_timers) {
            timer->reset();
        }
    }
};

inline stats::Timer::Timer(const char* name, bool default_enabled)
        : m_name(name), m_default_enabled(default_enabled) {
    Stats::sm_timers.push_back(this);
}

#if MGE_ENABLE_STATS
#define MGE_TIMER_SCOPE(name)         auto name = Stats::name.time_scope()
#define MGE_TIMER_SCOPE_RELEASE(name) name.release()
#define MGE_TIMER_SCOPE_ENABLE(name)  auto name = Stats::name.enable_scope()
#else
#define MGE_TIMER_SCOPE(name)         (void)0
#define MGE_TIMER_SCOPE_RELEASE(name) (void)0
#define MGE_TIMER_SCOPE_ENABLE(name)  (void)0
#endif

} // namespace imperative
} // namespace mgb
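The new header only wires up the timers; nothing in this commit enables them by default (MGE_ENABLE_STATS stays 0). A hypothetical usage sketch, assuming a timer named time_apply_op is added to Stats and the flag is flipped to 1, would look like this; the timer name and the measured function are illustrative only.

// Hypothetical sketch, not part of the commit.
// 1) Declare the timer as a member of Stats, next to the comment above:
//        static inline stats::Timer time_apply_op{"apply_op", /*default_enabled=*/true};
// 2) Rebuild with MGE_ENABLE_STATS set to 1 and scope the code to be measured:
void measured_hot_path() {
    MGE_TIMER_SCOPE(time_apply_op);  // auto time_apply_op = Stats::time_apply_op.time_scope();
    // ... work to be timed; the elapsed duration is accumulated when the scope ends
}
// 3) Read or clear the counters, e.g. through the Python bindings added in
//    init_tensor above: print_stats() calls Stats::print(); reset_stats() calls
//    Stats::reset().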
@@ -23,6 +23,7 @@
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {
@@ -58,6 +59,11 @@ public:
    inline size_t code() const { return m_code; }
};

enum class ValueKind {
    Primitive,
    Object,
};
/**
 * \brief a smart reference to a value
 *
@@ -129,10 +135,10 @@ public:
     * \return TypedValueRef<TValue> reference if success, otherwise empty reference
     */
    template <typename TValue>
    inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const;
    inline const TypedValueRef<TValue>& as_ref(Type<TValue> type = {}) const;

    template <typename TValue>
    inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const;
    inline const TypedValueRef<TValue>& cast_ref(Type<TValue> type = {}) const;

    template <typename TValue>
    void on_cast_failure() const;
@@ -161,14 +167,18 @@ public:
    static bool any_watching();

    static const ValueRef nil;

    friend class ValueWeakRef;
    template <typename T>
    template <typename>
    friend class TypedValueRef;
    template <typename T>
    template <typename, ValueKind>
    friend class ValueImpl;
    friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
};

inline const ValueRef ValueRef::nil;

template <>
struct ToStringTrait<ValueRef> {
public:
@@ -241,7 +251,7 @@ public:
    friend class ValueRef;
    friend class ValueWeakRef;
    template <typename T>
    template <typename, ValueKind>
    friend class ValueImpl;
    template <typename T>
    friend class TypedValueRef;
@@ -257,7 +267,7 @@ private:
 *
 * \tparam T type of value
 */
template <typename T>
template <typename T, ValueKind Kind>
class ValueImpl : public Value {
protected:
    ValueImpl() : Value(TYPE_CODE) {}
@@ -267,6 +277,7 @@ public:
    using weak_ref_t = TypedValueWeakRef<T>;

    static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
    static constexpr ValueKind KIND = Kind;
/** | /** | ||||
* \brief helper function for constructing a value | * \brief helper function for constructing a value | ||||
@@ -288,8 +299,8 @@ public: | |||||
* \tparam T type of value | * \tparam T type of value | ||||
* \tparam TMixin type of mixin class | * \tparam TMixin type of mixin class | ||||
*/ | */ | ||||
template <typename T, typename TMixin> | |||||
class MixinValueImpl : public ValueImpl<T>, public TMixin { | |||||
template <typename T, ValueKind Kind, typename TMixin> | |||||
class MixinValueImpl : public ValueImpl<T, Kind>, public TMixin { | |||||
public: | public: | ||||
using TMixin::TMixin; | using TMixin::TMixin; | ||||
@@ -309,12 +320,14 @@ inline ValueRef::ValueRef(storage_t storage) { | |||||
template <typename TValue> | template <typename TValue> | ||||
inline const TValue* ValueRef::as(Type<TValue> type) const { | inline const TValue* ValueRef::as(Type<TValue> type) const { | ||||
static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>); | |||||
// auto _ = Stats::time_value_as.time_scope(); | |||||
static_assert(std::is_base_of_v<Value, TValue>); | |||||
return static_cast<const TValue*>(as(type.code())); | return static_cast<const TValue*>(as(type.code())); | ||||
} | } | ||||
template <typename TValue> | template <typename TValue> | ||||
inline const TValue& ValueRef::cast(Type<TValue> type) const { | inline const TValue& ValueRef::cast(Type<TValue> type) const { | ||||
// auto _ = Stats::time_value_cast.time_scope(); | |||||
auto* ptr = as<TValue>(type); | auto* ptr = as<TValue>(type); | ||||
if (mgb_unlikely(!ptr)) { | if (mgb_unlikely(!ptr)) { | ||||
on_cast_failure<TValue>(); | on_cast_failure<TValue>(); | ||||
@@ -324,26 +337,27 @@ inline const TValue& ValueRef::cast(Type<TValue> type) const { | |||||
template <typename TValue> | template <typename TValue> | ||||
inline bool ValueRef::is(Type<TValue> type) const { | inline bool ValueRef::is(Type<TValue> type) const { | ||||
// auto _ = Stats::time_value_is.time_scope(); | |||||
return is(type.code()); | return is(type.code()); | ||||
} | } | ||||
template <typename TValue> | template <typename TValue> | ||||
inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const { | |||||
inline const TypedValueRef<TValue>& ValueRef::as_ref(Type<TValue> type) const { | |||||
if (!is<TValue>(type)) { | if (!is<TValue>(type)) { | ||||
return {}; | |||||
return TypedValueRef<TValue>::nil; | |||||
} | } | ||||
return TypedValueRef<TValue>(*this); | |||||
return *reinterpret_cast<const TypedValueRef<TValue>*>(this); | |||||
} | } | ||||
template <typename TValue> | template <typename TValue> | ||||
inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const { | |||||
inline const TypedValueRef<TValue>& ValueRef::cast_ref(Type<TValue> type) const { | |||||
if (!m_storage) { | if (!m_storage) { | ||||
return {}; | |||||
return TypedValueRef<TValue>::nil; | |||||
} | } | ||||
if (mgb_unlikely(!is<TValue>(type))) { | if (mgb_unlikely(!is<TValue>(type))) { | ||||
on_cast_failure<TValue>(); | on_cast_failure<TValue>(); | ||||
} | } | ||||
return TypedValueRef<TValue>(*this); | |||||
return *reinterpret_cast<const TypedValueRef<TValue>*>(this); | |||||
} | } | ||||
template <typename TValue> | template <typename TValue> | ||||
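Returning `const TypedValueRef<TValue>&` above avoids materializing a temporary `TypedValueRef` (and the accompanying refcount traffic) on every as_ref/cast_ref call. This is only sound because `TypedValueRef<T>` adds no data members to `ValueRef`, so the reinterpret_cast does not change object layout, and the static `nil` member provides a shared empty reference for the failure path. A standalone sketch of the pattern, using illustrative names rather than the real API:

    #include <memory>

    struct Ref {
        std::shared_ptr<void> m_storage = nullptr;   // stands in for ValueRef's storage
    };

    template <typename T>
    struct TypedRef : Ref {                          // no extra members: same size/layout as Ref
        static inline const TypedRef nil{};          // shared empty instance for failures
    };

    template <typename T>
    const TypedRef<T>& as_typed(const Ref& ref, bool holds_t) {
        static_assert(sizeof(TypedRef<T>) == sizeof(Ref));
        if (!holds_t) {
            return TypedRef<T>::nil;                 // no temporary, no refcount bump
        }
        return *reinterpret_cast<const TypedRef<T>*>(&ref);
    }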
@@ -363,12 +377,31 @@ void ValueRef::on_cast_failure() const { | |||||
template <typename T> | template <typename T> | ||||
class TypedValueRef : public ValueRef { | class TypedValueRef : public ValueRef { | ||||
private: | private: | ||||
TypedValueRef(ValueRef value) : ValueRef(value) {} | |||||
TypedValueRef(ValueRef value) : ValueRef(std::move(value)) {} | |||||
public: | public: | ||||
TypedValueRef() = default; | TypedValueRef() = default; | ||||
const T& operator*() const { return this->template cast<T>(); } | |||||
const T* operator->() const { return this->template as<T>(); } | |||||
const T& operator*() const { | |||||
if constexpr (T::KIND == ValueKind::Object) { | |||||
return this->template cast<T>(); | |||||
} else if constexpr (T::KIND == ValueKind::Primitive) { | |||||
if (!m_storage) { | |||||
on_cast_failure<T>(); | |||||
} | |||||
return static_cast<const T&>(*m_storage); | |||||
} else { | |||||
static_assert(!std::is_same_v<T, T>); | |||||
} | |||||
} | |||||
const T* operator->() const { | |||||
if constexpr (T::KIND == ValueKind::Object) { | |||||
return this->template as<T>(); | |||||
} else if constexpr (T::KIND == ValueKind::Primitive) { | |||||
return static_cast<const T*>(m_storage.get()); | |||||
} else { | |||||
static_assert(!std::is_same_v<T, T>); | |||||
} | |||||
} | |||||
/** | /** | ||||
* \brief reset underlying value to another value | * \brief reset underlying value to another value | ||||
@@ -376,6 +409,7 @@ public: | |||||
* \param successor new value | * \param successor new value | ||||
*/ | */ | ||||
inline void reset(ValueRef successor) { | inline void reset(ValueRef successor) { | ||||
static_assert(T::KIND == ValueKind::Object); | |||||
mgb_assert(m_storage); | mgb_assert(m_storage); | ||||
mgb_assert(!m_storage->m_successor); | mgb_assert(!m_storage->m_successor); | ||||
if (m_storage->m_watching) { | if (m_storage->m_watching) { | ||||
@@ -385,9 +419,11 @@ public: | |||||
m_storage->m_successor = ValueRef(successor.storage()); | m_storage->m_successor = ValueRef(successor.storage()); | ||||
} | } | ||||
static inline const TypedValueRef nil; | |||||
friend class ValueRef; | friend class ValueRef; | ||||
template <typename U> | |||||
template <typename, ValueKind> | |||||
friend class ValueImpl; | friend class ValueImpl; | ||||
}; | }; | ||||
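The new operator* / operator-> bodies above dispatch at compile time on the value's KIND: Object values keep the checked cast<> / as<> path, Primitive values are read directly from m_storage, and the trailing static_assert(!std::is_same_v<T, T>) turns any future, unhandled ValueKind into a compile error instead of a silent fallthrough. Below is a self-contained sketch of that if-constexpr dispatch idiom; the toy types are illustrative and not part of the diff.

    #include <cstdio>
    #include <type_traits>

    enum class Kind { Primitive, Object };

    struct SmallInt {
        static constexpr Kind KIND = Kind::Primitive;
        int v;
    };
    struct Checked {
        static constexpr Kind KIND = Kind::Object;
        int payload() const { return 42; }   // stands in for the checked cast<> path
    };

    template <typename T>
    int read(const T& value) {
        if constexpr (T::KIND == Kind::Primitive) {
            return value.v;                  // cheap direct access, no runtime type check
        } else if constexpr (T::KIND == Kind::Object) {
            return value.payload();          // routed through the checked accessor
        } else {
            // dependent-false assert: fires only if a new Kind is left unhandled here
            static_assert(!std::is_same_v<T, T>, "unhandled Kind");
        }
    }

    int main() {
        std::printf("%d %d\n", read(SmallInt{7}), read(Checked{}));
    }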
@@ -423,7 +459,7 @@ public: | |||||
ValueRefList() = default; | ValueRefList() = default; | ||||
ValueRefList(size_t nr_elems); | ValueRefList(size_t nr_elems); | ||||
ValueRefList(ValueRef item); | ValueRefList(ValueRef item); | ||||
ValueRefList(std::initializer_list<ValueRef> values); | |||||
// ValueRefList(std::initializer_list<ValueRef> values); | |||||
template <typename TIterator> | template <typename TIterator> | ||||
ValueRefList(TIterator begin, TIterator end); | ValueRefList(TIterator begin, TIterator end); | ||||
ValueRefList(const ValueRefList& rhs); | ValueRefList(const ValueRefList& rhs); | ||||