GitOrigin-RevId: 860028e1af
tags/v1.9.0
@@ -13,6 +13,7 @@
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "./tensor.h"
@@ -21,6 +21,7 @@
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"
@@ -52,8 +53,48 @@ namespace mgb::imperative::python {
namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;

struct SymbolVarContext {
    TransformationContext context;
    cg::ComputingGraph* graph;

    SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) {
        Transformation::swap_context(context);
    }

    void init() {
        std::make_shared<SymbolTransformation>(graph)->register_at(
                Transformation::top());
        std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
    }

    ~SymbolVarContext() { Transformation::swap_context(context); }
};

ValueRef symvar2val(py::handle py_symbol_var) {
    auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
    ValueRef value = SymbolValue::make(symbol_var->m_node);
    if (symbol_var->is_scalar) {
        value = ScalarValue::make(value);
    }
    return value;
}

py::object val2symvar(py::handle typeobj, ValueRef value) {
    bool is_scalar = false;
    if (auto* scalar_value = value.as<ScalarValue>()) {
        value = scalar_value->value();
        is_scalar = true;
    }
    auto* node = value.cast<SymbolValue>().node();
    auto py_symbol_var =
            typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
    py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
    return py_symbol_var;
}
} // namespace

interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyObject *cpp_use_symbolic_shape, *cpp_astensor1d;
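
// The two helpers above fold the Python-side scalar flag into the value on the
// way in (symvar2val) and peel it back off on the way out (val2symvar). A
// standalone sketch of that round-trip, with hypothetical stand-in types
// (Node/Symbol/Scalar) instead of the real ValueRef machinery:

#include <cassert>
#include <utility>
#include <variant>

struct Node { int id; };
struct Symbol { const Node* node; };
struct Scalar { Symbol inner; };  // marks "this symbol is logically a scalar"
using Value = std::variant<Symbol, Scalar>;

// cf. symvar2val: wrap the node, then tag it if the flag is set
Value to_value(const Node* node, bool is_scalar) {
    Symbol symbol{node};
    if (is_scalar) {
        return Scalar{symbol};
    }
    return symbol;
}

// cf. val2symvar: strip the tag, recovering both the node and the flag
std::pair<const Node*, bool> from_value(const Value& value) {
    if (auto* scalar = std::get_if<Scalar>(&value)) {
        return {scalar->inner.node, true};
    }
    return {std::get<Symbol>(value).node, false};
}

int main() {
    Node n{42};
    auto [node, is_scalar] = from_value(to_value(&n, true));
    assert(node == &n && is_scalar);  // the round trip preserves the flag
}
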
@@ -91,36 +132,17 @@ PyObject* py_apply(
    if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
        // swap to a special context to reuse scalar handle
        TransformationContext symbol_var_context;
        Transformation::swap_context(symbol_var_context);
        CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }};
        auto* graph =
                py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph();
        std::make_shared<SymbolTransformation>(graph)->register_at(
                Transformation::top());
        std::make_shared<ScalarTransformation>()->register_at(
                Transformation::top());
        SymbolVarContext context(
                py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
        context.init();
        for (size_t i = 0; i < nargs; ++i) {
            auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
            ValueRef input = SymbolValue::make(py_input->m_node);
            if (py_input->is_scalar) {
                input = ScalarValue::make(input);
            }
            tensors[i] = input;
            tensors[i] = symvar2val(args[i]);
        }
        auto outputs = imperative::apply(*op, tensors);
        auto ret = pybind11::tuple(outputs.size());
        auto typeobj = py::handle(args[0]).get_type();
        for (size_t i = 0; i < outputs.size(); ++i) {
            bool is_scalar = false;
            if (auto* scalar_value = outputs[i].as<ScalarValue>()) {
                outputs[i] = scalar_value->value();
                is_scalar = true;
            }
            auto* node = outputs[i].cast<SymbolValue>().node();
            ret[i] = typeobj(
                    pybind11::cast(node, pybind11::return_value_policy::automatic));
            py::handle(ret[i]).cast<PySymbolVar*>()->is_scalar = is_scalar;
            ret[i] = val2symvar(typeobj, outputs[i]);
        }
        return ret.release().ptr();
    }
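
// The hunk above collapses a manual swap_context / CleanupGuard pair into the
// SymbolVarContext RAII type. A minimal self-contained sketch of that pattern,
// with hypothetical stand-in types: the constructor swaps a scratch context
// in, the destructor swaps the original back, so every exit path (including
// exceptions) restores the caller's state.

#include <utility>
#include <vector>

struct Context {
    std::vector<int> transformations;
};

inline Context g_context;  // stand-in for the thread's transformation stack

class ScopedContext {
    Context m_saved;  // caller's context, restored on destruction
public:
    ScopedContext() { std::swap(m_saved, g_context); }
    ~ScopedContext() { std::swap(m_saved, g_context); }
    ScopedContext(const ScopedContext&) = delete;
    ScopedContext& operator=(const ScopedContext&) = delete;

    // cf. SymbolVarContext::init: populate the scratch context after the swap
    void init() { g_context.transformations.push_back(0); }
};

size_t apply_in_scratch_context() {
    ScopedContext scope;  // caller's context saved here
    scope.init();
    return g_context.transformations.size();  // work against the scratch context
}  // caller's context restored here
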
@@ -1537,17 +1559,29 @@ void init_tensor(py::module m) {
        }
    });
    m.def("reduce_to_scalar", [](py::object op, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        auto make_scalar_shape = [&](CompNode device) {
            return imperative::apply(
                    CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
                    HostStorage::make(device))[0];
    m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
        auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
            auto make_scalar_shape = [&](CompNode device) {
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
                        HostStorage::make(device))[0];
            };
            return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
        };
        auto output = imperative::apply(
                *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(),
                make_scalar_shape(tw->m_tensor->comp_node()))[0];
        return TensorWrapper::make(py_tensor_type, output);
        if (py::isinstance<PySymbolVar>(tensor)) {
            auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
            SymbolVarContext context(graph);
            context.init();
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), symvar2val(tensor));
            auto typeobj = tensor.get_type();
            return val2symvar(typeobj, output);
        } else {
            auto* tw = TensorWrapper::try_cast(tensor.ptr());
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
            return TensorWrapper::make(py_tensor_type, output);
        }
    });
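
// reduce_to_scalar now dispatches on the Python-side type of `tensor`: the
// symbolic branch stays symbolic, the eager branch returns a TensorWrapper. A
// minimal pybind11 sketch of that shape with hypothetical DenseTensor/SymVar
// classes (not the real MegEngine bindings):

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

struct DenseTensor { double value; };
struct SymVar { std::string name; };

PYBIND11_MODULE(dispatch_demo, m) {
    py::class_<DenseTensor>(m, "DenseTensor").def(py::init<double>());
    py::class_<SymVar>(m, "SymVar").def(py::init<std::string>());
    m.def("describe", [](py::object obj) -> py::object {
        if (py::isinstance<SymVar>(obj)) {
            // symbolic path: keep the result symbolic
            return py::str("sym:" + obj.cast<SymVar*>()->name);
        }
        // eager path: unwrap the C++ object and return a concrete value
        return py::float_(obj.cast<DenseTensor*>()->value);
    });
}
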
m.def("name_tensor", [](std::string name, py::object tensor) { | |||
@@ -1557,7 +1591,7 @@ void init_tensor(py::module m) { | |||
}); | |||
m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool { | |||
ValueRefList values(tensors.size()); | |||
SmallVector<ValueRef> values(tensors.size()); | |||
for (size_t i = 0; i < tensors.size(); ++i) { | |||
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | |||
} | |||
@@ -1570,17 +1604,16 @@ void init_tensor(py::module m) { | |||
}); | |||
m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object { | |||
ValueRefList values(tensors.size()); | |||
SmallVector<ValueRef> values(tensors.size()); | |||
for (size_t i = 0; i < tensors.size(); ++i) { | |||
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | |||
} | |||
auto outputs = imperative::apply(GetGradKey(), values); | |||
if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) { | |||
return py::reinterpret_borrow<py::object>( | |||
GradKeyWrapper::wrap_t::pycast(GradKeyWrapper::get(*grad_key_val))); | |||
} else { | |||
auto output = imperative::apply(GetGradKey(), values)[0]; | |||
if (!output) { | |||
return py::none(); | |||
} | |||
return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast( | |||
GradKeyWrapper::get(output.cast<GradKeyValue>()))); | |||
}); | |||
m.def("set_grad", [](py::object py_key, py::function backward_fn, | |||
@@ -1612,7 +1645,7 @@ void init_tensor(py::module m) { | |||
} | |||
return input_grads; | |||
}; | |||
ValueRefList values(inputs.size() + outputs.size()); | |||
SmallVector<ValueRef> values(inputs.size() + outputs.size()); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data(); | |||
} | |||
@@ -1669,6 +1702,10 @@ void init_tensor(py::module m) { | |||
return reprs; | |||
}); | |||
m.def("print_stats", [] { imperative::Stats::print(); }); | |||
m.def("reset_stats", [] { imperative::Stats::reset(); }); | |||
py::register_exception<TraceError>(m, "TraceError"); | |||
} | |||
@@ -67,7 +67,8 @@ struct TransformationManager { | |||
} | |||
}; | |||
class PyValue final : public MixinValueImpl<PyValue, pybind11::object> { | |||
class PyValue final | |||
: public MixinValueImpl<PyValue, ValueKind::Object, pybind11::object> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
@@ -14,13 +14,9 @@ | |||
#include "megbrain/imperative/utils/debug.h" | |||
#include "megbrain/imperative/utils/helper.h" | |||
#include "megbrain/imperative/utils/map.h" | |||
#include "megbrain/imperative/utils/stats.h" | |||
namespace mgb { | |||
void imperative_log_profile_begin(const char* message); | |||
void imperative_log_profile(const char* message); | |||
void imperative_log_profile_end(const char* message); | |||
namespace imperative { | |||
namespace { | |||
@@ -19,6 +19,7 @@ | |||
#include "megbrain/imperative/ops/backward_graph.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/ops/utility.h" | |||
#include "megbrain/imperative/utils/stats.h" | |||
#include "megbrain/imperative/utils/to_string.h" | |||
#include "../blob_manager_impl.h" | |||
@@ -1,4 +1,5 @@ | |||
#include "megbrain/imperative/transformation.h" | |||
#include "megbrain/imperative/utils/stats.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -11,6 +11,7 @@ | |||
#include "megbrain/imperative/transformations/eval.h" | |||
#include "megbrain/imperative/transformations/grad.h" | |||
#include "megbrain/imperative/utils/stats.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -40,9 +41,6 @@ ShapeValue::ref_t InterpreterInfo::shape() const { | |||
ValueRefList InterpreterTransformation::apply_op( | |||
const ApplyOp& apply_op, Span<ValueRef> inputs) { | |||
if (apply_op.op().same_type<FastpathCopy>()) { | |||
return {inputs[0]}; | |||
} | |||
SmallVector<Handle> input_handles; | |||
SmallVector<Handle> output_handles; | |||
CleanupGuard _{[&] { | |||
@@ -111,7 +109,11 @@ ValueRefList InterpreterTransformation::apply_create_tensor( | |||
ValueRefList InterpreterTransformation::apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) { | |||
if (auto* op_val = op.as<ApplyOp>()) { | |||
return apply_op(*op_val, inputs); | |||
if (op_val->op().same_type<FastpathCopy>()) { | |||
return inputs[0]; | |||
} else { | |||
return apply_op(*op_val, inputs); | |||
} | |||
} else if (auto* get_attr = op.as<GetAttr>()) { | |||
return apply_get_attr(*get_attr, inputs); | |||
} else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
@@ -11,8 +11,11 @@ | |||
#include "megbrain/imperative/transformations/grad.h" | |||
#include <variant> | |||
#include "megbrain/imperative/graph_cache.h" | |||
#include "megbrain/imperative/resource_manager.h" | |||
#include "megbrain/imperative/utils/stats.h" | |||
#include <range/v3/all.hpp> | |||
@@ -20,20 +23,21 @@ namespace mgb { | |||
namespace imperative { | |||
static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph( | |||
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs, | |||
const OpDef& op, Span<ValueRef> inputs, Span<ValueRef> outputs, | |||
Span<bool> inputs_require_grad) { | |||
// hash | |||
using OptimizedBackwardGraphCache = OpMethResultCache< | |||
std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>; | |||
thread_local auto& cache = | |||
*ResourceManager::create_local<OptimizedBackwardGraphCache>(); | |||
OptimizedBackwardGraphCache::key_t cache_key{op}; | |||
OptimizedBackwardGraphCache::key_t cache_key{op.shared_from_this()}; | |||
SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs; | |||
std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>(); | |||
cache_key.extra<0>() = inputs_require_grad.copy_into<SmallVector<bool>>(); | |||
input_descs.resize(inputs.size()); | |||
// some overhead, consider simplify LogicalTensorDesc | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>(); | |||
input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>(); | |||
input_descs[i].layout.dtype = *inputs[i].dtype(); | |||
input_descs[i].comp_node = *inputs[i].device(); | |||
} | |||
auto iter = cache.find(cache_key); | |||
@@ -45,7 +49,7 @@ static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_gra | |||
SmallVector<bool> output_has_grad(outputs.size(), true); | |||
std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||
auto bg = OpDef::make_backward_graph( | |||
*op, input_descs, std::get<0>(cache_key.extras), output_has_grad); | |||
op, input_descs, std::get<0>(cache_key.extras), output_has_grad); | |||
if (!bg.graph.empty()) { | |||
ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||
} | |||
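
// make_optimized_backward_graph memoizes per thread on (op, input descs,
// requires-grad mask); the commit routes the map's storage through
// ResourceManager::create_local so teardown order is controlled. A standalone
// sketch of the bare thread_local memoization pattern, with hypothetical
// key/result types:

#include <memory>
#include <string>
#include <unordered_map>

struct BackwardGraph {
    std::string repr;
};

std::shared_ptr<BackwardGraph> lookup_backward(const std::string& key) {
    thread_local std::unordered_map<std::string, std::shared_ptr<BackwardGraph>>
            cache;
    auto it = cache.find(key);
    if (it != cache.end()) {
        return it->second;  // hit: reuse the previously built graph
    }
    auto graph = std::make_shared<BackwardGraph>(BackwardGraph{"bwd(" + key + ")"});
    cache.emplace(key, graph);  // miss: build once, remember per thread
    return graph;
}
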
@@ -235,7 +239,7 @@ GradValue::ref_t GradKey::attach(
    } else {
        GradSlotPtr grad_slot;
        auto& grad_fn = grad_slot.m_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn = LocalPtr<GradFn>::make();
        grad_fn->m_key = shared_from_this();
        grad_fn->m_slots.resize(1);
        grad_slot.m_index = 0;
@@ -260,17 +264,21 @@ ValueRefList GradTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    auto fallback = [&] {
        ValueRefList unwrapped_inputs(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (auto grad_value = as_grad_value(inputs[i])) {
                unwrapped_inputs[i] = grad_value->m_value;
            } else {
                unwrapped_inputs[i] = inputs[i];
        {
            // overhead
            for (size_t i = 0; i < inputs.size(); ++i) {
                if (auto&& grad_value = as_grad_value(inputs[i])) {
                    unwrapped_inputs[i] = grad_value->m_value;
                } else {
                    unwrapped_inputs[i] = inputs[i];
                }
            }
        }
        return imperative::apply(op, unwrapped_inputs);
    };
    if (auto* get_attr = op.as<GetAttr>()) {
        if (auto grad_value = as_grad_value(inputs.item())) {
    if (op.is<GetAttr>()) {
        // overhead
        if (auto&& grad_value = as_grad_value(inputs.item())) {
            return imperative::apply(op, grad_value->m_value);
        } else {
            return imperative::apply(op, inputs);
@@ -281,28 +289,29 @@ ValueRefList GradTransformation::apply_transformation(
    }
    if (auto* op_val = op.as<ApplyOp>()) {
        size_t nr_require_grad = 0;
        SmallVector<bool> require_grads;
        for (auto&& input : inputs) {
            if (is_grad_value(input)) {
        SmallVector<bool> require_grads(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (is_grad_value(inputs[i])) {
                nr_require_grad++;
                require_grads.push_back(true);
                require_grads[i] = true;
            } else {
                require_grads.push_back(false);
                require_grads[i] = false;
            }
        }
        if (nr_require_grad == 0) {
            return imperative::apply(op, inputs);
        }
        ValueRefList captured_inputs(inputs.size());
        SmallVector<ValueRef> captured_inputs(inputs.size());
        SmallVector<bool> inputs_require_grad(inputs.size());
        // capture value so that trace could assume input as same
        auto capture_value = [](ValueRef value) {
        auto capture_value = [](const ValueRef& value) {
            // TODO: fastpath copy shouldn't be an OpDef
            return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
            static auto fastpath_copy = FastpathCopy::make();
            return imperative::apply(ApplyOp(*fastpath_copy), value)[0];
        };
        for (size_t i = 0; i < inputs.size(); ++i) {
            auto& input = inputs[i];
            if (auto grad_value = as_grad_value(input)) {
            if (auto&& grad_value = as_grad_value(input)) {
                captured_inputs[i] = capture_value(grad_value->m_value);
                inputs_require_grad[i] = true;
            } else {
@@ -310,32 +319,28 @@ ValueRefList GradTransformation::apply_transformation(
                inputs_require_grad[i] = false;
            }
        }
        decltype(std::declval<GradFn>().m_backward) backward_storage;
        // copy grad_fn->m_backward is expensive
        auto grad_fn = LocalPtr<GradFn>::make();
        auto& backward_storage = grad_fn->m_backward;
        auto outputs = [&] {
            auto backward_rule =
                    CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
            if (backward_rule) {
                CustomBackward backward;
                auto optional_outputs = backward_rule(
                        op_val->op(), {captured_inputs.data(), captured_inputs.size()},
                        {inputs_require_grad.data(), inputs_require_grad.size()},
                        backward);
                        op_val->op(), captured_inputs, inputs_require_grad, backward);
                if (optional_outputs) {
                    backward_storage = backward;
                    // backward by rule
                    return *optional_outputs;
                }
            }
            auto outputs = imperative::apply(
                    op, {captured_inputs.begin(), captured_inputs.end()});
            auto outputs = imperative::apply(op, captured_inputs);
            auto backward_graph = make_optimized_backward_graph(
                    op.cast<ApplyOp>().op().shared_from_this(),
                    {captured_inputs.begin(), captured_inputs.end()},
                    {outputs.data(), outputs.size()},
                    {inputs_require_grad.data(), inputs_require_grad.size()});
                    op_val->op(), captured_inputs, outputs, inputs_require_grad);
            if (backward_graph) {
                backward_storage = BackwardGraphWithClosure(
                        backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
                        backward_graph, op_val->op().shared_from_this(),
                        {captured_inputs.begin(), captured_inputs.end()},
                        {outputs.data(), outputs.size()});
                // backward by make_backward_graph
@@ -348,18 +353,17 @@ ValueRefList GradTransformation::apply_transformation(
        if (std::holds_alternative<std::monostate>(backward_storage)) {
            return outputs;
        }
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(outputs.size());
        grad_fn->m_backward = backward_storage;
        mgb_assert(!outputs.empty());
        grad_fn->m_dests.reserve(inputs.size());
        // clang-format off
        std::visit([&](auto& backward) {
        auto visitor = [&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                // little overhead
                for (size_t i = 0; i < inputs.size(); ++i) {
                    if (backward.input_has_grad(i) && require_grads[i]) {
                        auto& input_grad_slot =
@@ -373,19 +377,23 @@ ValueRefList GradTransformation::apply_transformation(
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        // little overhead: Value::make
                        auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
                        outputs[i] = record_grad(grad_value);
                    }
                }
            }
        }, grad_fn->m_backward);
        };
        // std::visit may be slightly slower than direct if
        std::visit(visitor, backward_storage);
        // clang-format on
        mgb_assert(!grad_fn->m_slots.empty());
        m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
        return outputs;
    } else if (op.is<CreateTensor>()) {
        return imperative::apply(op, inputs);
    } else if (auto* attach_grad = op.as<AttachGrad>()) {
    }
    if (auto* attach_grad = op.as<AttachGrad>()) {
        if (!has_key(attach_grad->key())) {
            return fallback();
        }
@@ -408,7 +416,7 @@ ValueRefList GradTransformation::apply_transformation(
        return {};
    } else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
        if (has_key(is_attached_to->key())) {
            if (auto grad_value = as_grad_value(inputs[0])) {
            if (auto&& grad_value = as_grad_value(inputs[0])) {
                // TODO: assert grad_fn
                return {BoolValue::make(true)};
            }
@@ -416,7 +424,7 @@ ValueRefList GradTransformation::apply_transformation(
        return {BoolValue::make(false)};
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // TODO: merge SetGrad and ApplyOp
        auto grad_fn = std::make_shared<GradFn>();
        auto grad_fn = LocalPtr<GradFn>::make();
        auto& backward =
                std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
        size_t nr_inputs = set_grad->nr_inputs();
@@ -433,7 +441,7 @@ ValueRefList GradTransformation::apply_transformation(
        grad_fn->m_slots.resize(nr_outputs);
        grad_fn->m_dests.reserve(nr_inputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (auto grad_value = as_grad_value(inputs_[i])) {
            if (auto&& grad_value = as_grad_value(inputs_[i])) {
                auto& input_grad_slot = grad_value->m_slot;
                grad_fn->m_dests.emplace_back(grad_value->m_slot);
                grad_fn->m_dests.back().m_producer_record.insert_after(
@@ -461,21 +469,21 @@ ValueRefList GradTransformation::apply_transformation(
        }
        return {FunctionValue::make(make_backward_closure(inputs))};
    } else if (op.is<DetachGrad>()) {
        if (auto grad_value = as_grad_value(inputs[0])) {
        if (auto&& grad_value = as_grad_value(inputs[0])) {
            return {grad_value->m_value};
        } else {
            return {inputs[0]};
        }
    } else if (op.is<GetGradKey>()) {
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
            if (auto&& grad_value = as_grad_value(input)) {
                return {GradKeyValue::make(grad_value->m_key)};
            }
        }
        return imperative::apply(op, inputs);
    } else if (op.kind() == Operator::IdentityLike) {
        mgb_assert(inputs.size() == 1);
        if (auto grad_value = as_grad_value(inputs[0])) {
        if (auto&& grad_value = as_grad_value(inputs[0])) {
            auto output = imperative::apply(op, grad_value->m_value)[0];
            auto grad_output = GradValue::make(
                    output, grad_value->key(), grad_value->slot_for(m_key));
@@ -493,7 +501,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
    auto grad_key = m_key;
    std::vector<GradSlotPtr> y_slots;
    for (auto&& y : ys) {
        if (auto grad_value = as_grad_value(y)) {
        if (auto&& grad_value = as_grad_value(y)) {
            y_slots.push_back(grad_value->slot_for(grad_key));
        } else {
            y_slots.emplace_back();
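
// GradFn::m_backward is a std::variant<std::monostate, BackwardGraphWithClosure,
// CustomBackward>, and the code above walks it with a named visitor where
// monostate means "no backward recorded" and is a hard error. A standalone
// sketch of that dispatch with hypothetical alternatives:

#include <cstdio>
#include <stdexcept>
#include <type_traits>
#include <variant>

struct GraphBackward {
    void run() const { std::puts("graph backward"); }
};
struct CustomRule {
    void run() const { std::puts("custom rule"); }
};
using BackwardStorage = std::variant<std::monostate, GraphBackward, CustomRule>;

void run_backward(const BackwardStorage& storage) {
    auto visitor = [](const auto& backward) {
        using T = std::decay_t<decltype(backward)>;
        if constexpr (std::is_same_v<T, std::monostate>) {
            throw std::logic_error("invalid backward");  // cf. mgb_throw above
        } else {
            backward.run();  // both live alternatives share this interface
        }
    };
    std::visit(visitor, storage);
}

int main() {
    run_backward(BackwardStorage{GraphBackward{}});
    run_backward(BackwardStorage{CustomRule{}});
}
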
@@ -13,6 +13,7 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {
@@ -185,7 +186,7 @@ ValueRefList subtensor_rule(
    bool is_scalar;
    mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
    if (auto shape = input.shape()) {
        size_t ndim = input.shape()->ndim;
        size_t ndim = shape->ndim;
        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
            if (idx) {
                ndim--;
@@ -193,6 +194,7 @@ ValueRefList subtensor_rule(
        }
        is_scalar = ndim == 0;
    } else {
        // assume not scalar
        is_scalar = false;
    }
    auto outputs = imperative::apply(subtensor, inputs);
@@ -341,12 +343,16 @@ ValueRefList ScalarTransformation::apply_transformation(
    if (auto* get_attr = op.as<GetAttr>()) {
        // fastpath for GetAttr
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* apply_op = op.as<ApplyOp>()) {
        if (apply_op->op().same_type<FastpathCopy>()) {
            return inputs[0];
        }
    }
    size_t nr_inputs = inputs.size();
    ValueRefList unwrapped_inputs(nr_inputs);
    bool inputs_mask[nr_inputs];
    SmallVector<bool> inputs_mask(nr_inputs);
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) {
        if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) {
            unwrapped_inputs[i] = scalar_value->value();
            inputs_mask[i] = true;
        } else {
@@ -358,8 +364,7 @@ ValueRefList ScalarTransformation::apply_transformation(
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(
                    apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs});
            return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask);
        } else {
            // TODO: repeat op
            return fallback();
@@ -215,8 +215,8 @@ ValueRefList::ValueRefList(size_t nr_elems) {
    init(nr_elems);
}

ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
        : ValueRefList(values.begin(), values.end()) {}
/*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
        : ValueRefList(values.begin(), values.end()) {}*/

ValueRefList::ValueRefList(const ValueRefList& rhs)
        : ValueRefList(rhs.cbegin(), rhs.cend()) {}
@@ -25,14 +25,16 @@ class GradKey;

using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;

class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
class ShapeValue final
        : public MixinValueImpl<ShapeValue, ValueKind::Primitive, ValueShape> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class CompNodeValue final : public MixinValueImpl<CompNodeValue, CompNode> {
class CompNodeValue final
        : public MixinValueImpl<CompNodeValue, ValueKind::Primitive, CompNode> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -40,7 +42,7 @@ public:
};

// TODO: override factory method
class BoolValue final : public ValueImpl<BoolValue> {
class BoolValue final : public ValueImpl<BoolValue, ValueKind::Primitive> {
private:
    std::optional<bool> m_value;
@@ -53,14 +55,17 @@ public:
    void clear() override { m_value.reset(); }
};

class HostStorage final : public MixinValueImpl<HostStorage, HostTensorStorage> {
class HostStorage final
        : public MixinValueImpl<HostStorage, ValueKind::Primitive, HostTensorStorage> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class DeviceStorage final : public MixinValueImpl<DeviceStorage, DeviceTensorStorage> {
class DeviceStorage final
        : public MixinValueImpl<
                  DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -71,7 +76,7 @@ public:
 * \brief like HostTensorND mixin, but allow scalar value
 *
 */
class HostValue final : public ValueImpl<HostValue> {
class HostValue final : public ValueImpl<HostValue, ValueKind::Primitive> {
private:
    DType m_dtype;
    ValueShape m_shape;
@@ -94,9 +99,9 @@ public:
    }

    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    const ValueShape& shape() const { return m_shape; }
    CompNode device() const { return m_storage.comp_node(); }
    HostTensorStorage storage() const { return m_storage; }
    const HostTensorStorage& storage() const { return m_storage; }
    DTypeScalar item() const {
        mgb_assert(m_shape.is_scalar());
        return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());
@@ -109,7 +114,7 @@ public:
 * \brief like DeviceTensorND mixin, but allow scalar value
 *
 */
class DeviceValue final : public ValueImpl<DeviceValue> {
class DeviceValue final : public ValueImpl<DeviceValue, ValueKind::Primitive> {
private:
    DType m_dtype;
    ValueShape m_shape;
@@ -117,8 +122,8 @@ private:
public:
    DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage)
            : m_dtype(dtype), m_shape(shape), m_storage(storage) {}
    DeviceValue(DeviceTensorND value)
            : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
    DeviceValue(const DeviceTensorND& value)
            : DeviceValue(
                      value.dtype(), ValueShape::from(value.shape()), value.storage()) {
    }
@@ -132,28 +137,31 @@ public:
    }

    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    const ValueShape& shape() const { return m_shape; }
    CompNode device() const { return m_storage.comp_node(); }
    DeviceTensorStorage storage() const { return m_storage; }
    const DeviceTensorStorage& storage() const { return m_storage; }

    DeviceTensorND as_nd(bool allow_scalar = false) const;
};

class FunctionValue final : public MixinValueImpl<FunctionValue, GenericFunction> {
class FunctionValue final
        : public MixinValueImpl<FunctionValue, ValueKind::Primitive, GenericFunction> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class DTypeValue final : public MixinValueImpl<DTypeValue, DType> {
class DTypeValue final
        : public MixinValueImpl<DTypeValue, ValueKind::Primitive, DType> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override;
};

class StringValue final : public MixinValueImpl<StringValue, std::string> {
class StringValue final
        : public MixinValueImpl<StringValue, ValueKind::Primitive, std::string> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -171,7 +179,8 @@ public:
    std::string message() const { return m_message; }
};

class ErrorValue final : public MixinValueImpl<ErrorValue, Error> {
class ErrorValue final
        : public MixinValueImpl<ErrorValue, ValueKind::Primitive, Error> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -47,9 +47,14 @@ constexpr bool is_all_value_ref_v =
        (... && (std::is_base_of_v<ValueRef, std::decay_t<TArgs>> ||
                 std::is_same_v<ValueRef, std::decay_t<TArgs>>));

template <typename T>
static ValueRefList apply(T&& op, const ValueRef& arg) {
    return imperative::apply(std::forward<T&&>(op), Span<ValueRef>{&arg, 1});
}

template <typename T, typename... TArgs>
static auto apply(T&& op, TArgs&&... args)
        -> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> {
static auto apply(T&& op, TArgs&&... args) -> std::enable_if_t<
        is_all_value_ref_v<TArgs...> && sizeof...(args) != 1, ValueRefList> {
    ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...};
    return imperative::apply(
            std::forward<T&&>(op),
@@ -54,6 +54,11 @@ struct OpMethArgs {
        return extras == rhs.extras;
    }

    template <size_t i>
    auto& extra() {
        return std::get<i>(extras);
    }

    struct hash_t {
        size_t operator()(const OpMethArgs& key) const { return key.hash(); }
    };
@@ -60,7 +60,7 @@ public:
};

class InterpreterValue final
        : public MixinValueImpl<InterpreterValue, InterpreterInfo> {
        : public MixinValueImpl<InterpreterValue, ValueKind::Object, InterpreterInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -104,37 +104,15 @@ struct ToStringTrait<GradSlot> {
    std::string operator()(const GradSlot& value) const { return value.to_string(); }
};

class GradFn {
private:
    std::weak_ptr<GradKey> m_key;
    std::vector<GradSlot> m_slots;
    std::vector<GradSlotProducerPtr> m_dests;
    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;

public:
    void clear() {
        m_key.reset();
        m_slots.clear();
        m_dests.clear();
        m_backward.emplace<std::monostate>();
    }

    std::string to_string() const;

    friend class GradSlotPtr;
    friend class GradKey;
    friend class GradTransformation;
};

class GradSlotPtr {
private:
    std::shared_ptr<GradFn> m_fn;
    LocalPtr<GradFn> m_fn;
    size_t m_index = 0;

public:
    GradSlotPtr(std::shared_ptr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
    GradSlotPtr(LocalPtr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
    GradSlotPtr() = default;
    GradSlot* operator->() const { return &m_fn->m_slots[m_index]; }
    GradSlot* operator->() const;

    operator bool() const { return bool(m_fn); }
@@ -171,7 +149,33 @@ struct ToStringTrait<GradSlotProducerPtr> {
    }
};

class GradValue final : public ValueImpl<GradValue> {
class GradFn {
private:
    std::weak_ptr<GradKey> m_key;
    SmallVector<GradSlot> m_slots;
    SmallVector<GradSlotProducerPtr> m_dests;
    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;

public:
    void clear() {
        m_key.reset();
        m_slots.clear();
        m_dests.clear();
        m_backward.emplace<std::monostate>();
    }

    std::string to_string() const;

    friend class GradSlotPtr;
    friend class GradKey;
    friend class GradTransformation;
};

inline GradSlot* GradSlotPtr::operator->() const {
    return &m_fn->m_slots[m_index];
}

class GradValue final : public ValueImpl<GradValue, ValueKind::Object> {
private:
    ValueRef m_value;
    std::shared_ptr<GradKey> m_key;
@@ -179,7 +183,7 @@ private:
public:
    GradValue(ValueRef value, std::shared_ptr<GradKey> key, GradSlotPtr slot = {})
            : m_value(value), m_key(key), m_slot(slot) {}
            : m_value(std::move(value)), m_key(std::move(key)), m_slot(slot) {}

    std::string to_string() const override;
@@ -209,12 +213,13 @@ public:
class GradKey : public std::enable_shared_from_this<GradKey> {
private:
    std::string m_name;
    std::vector<std::pair<std::weak_ptr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
    std::vector<std::pair<std::shared_ptr<GradFn>, std::shared_ptr<OpDef>>>
            m_frozen_tape;
    std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
    std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
    bool m_frozen = false;

public:
    GradKey() { m_tape.reserve(4 * 1024); }

    void backward();
    GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
    const std::string& name() const { return m_name; }
@@ -225,7 +230,8 @@ public:
};

class GradKeyValue final
        : public MixinValueImpl<GradKeyValue, std::shared_ptr<GradKey>> {
        : public MixinValueImpl<
                  GradKeyValue, ValueKind::Primitive, std::shared_ptr<GradKey>> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -248,7 +254,7 @@ public:
        return tensor;
    }

    bool is_grad_value(ValueRef value) {
    bool is_grad_value(const ValueRef& value) {
        if (auto* grad_value = value.as<GradValue>()) {
            if (grad_value->has_key(m_key)) {
                return true;
@@ -266,13 +272,14 @@ public:
     * \param value
     * \return GradValue::ref_t
     */
    GradValue::ref_t as_grad_value(ValueRef value) {
        if (auto grad_value = value.as_ref<GradValue>()) {
    const GradValue::ref_t& as_grad_value(const ValueRef& value) {
        auto&& grad_value = value.as_ref<GradValue>();
        if (grad_value) {
            if (grad_value->has_key(m_key)) {
                return grad_value;
            }
        }
        return {};
        return GradValue::ref_t::nil;
    }

    bool has_key(std::shared_ptr<GradKey> key) {
@@ -39,7 +39,8 @@ public:
    std::string name() const { return m_name; }
};

class LazyEvalValue final : public MixinValueImpl<LazyEvalValue, LazyEvalInfo> {
class LazyEvalValue final
        : public MixinValueImpl<LazyEvalValue, ValueKind::Object, LazyEvalInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -17,7 +17,7 @@

namespace mgb::imperative {

class ScalarValue final : public ValueImpl<ScalarValue> {
class ScalarValue final : public ValueImpl<ScalarValue, ValueKind::Object> {
private:
    ValueRef m_value;
@@ -22,7 +22,7 @@

namespace mgb::imperative {

class SymbolValue final : public ValueImpl<SymbolValue> {
class SymbolValue final : public ValueImpl<SymbolValue, ValueKind::Object> {
private:
    VarNode* m_node = nullptr;
@@ -111,7 +111,8 @@ public:
    size_t id() const { return m_id; }
};

class TracingValue final : public MixinValueImpl<TracingValue, TracingInfo> {
class TracingValue final
        : public MixinValueImpl<TracingValue, ValueKind::Object, TracingInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -256,7 +257,8 @@ public:
    }
};

class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
class TracedValue final
        : public MixinValueImpl<TracedValue, ValueKind::Object, TracedInfo> {
public:
    using MixinValueImpl::MixinValueImpl;
@@ -0,0 +1,140 @@
#pragma once

#include <chrono>
#include <iostream>
#include <string>
#include <vector>

namespace mgb {
namespace imperative {

namespace stats {

#define MGE_ENABLE_STATS 0

class Timer {
public:
    using clock_t = std::chrono::system_clock;

private:
    clock_t::duration m_duration = clock_t::duration{0};
    size_t m_timing = 0;
    const char* m_name = nullptr;
    uint64_t m_count = 0;
    size_t m_enabled = 1;
    bool m_default_enabled = true;

    struct TimeScopeRecursive {
        Timer& timer;
        clock_t::time_point start;
        bool released = false;

        TimeScopeRecursive(Timer& timer) : timer(timer) {
            if (timer.m_enabled && !timer.m_timing++) {
                start = clock_t::now();
            }
        }

        ~TimeScopeRecursive() { release(); }

        void release() {
            if (released) {
                return;
            }
            if (timer.m_enabled) {
                if (!--timer.m_timing) {
                    timer.m_duration += (clock_t::now() - start);
                }
                timer.m_count++;
            }
            released = true;
        }
    };

    struct EnableScope {
        Timer& timer;
        bool released = false;

        EnableScope(Timer& timer) : timer(timer) { timer.m_enabled++; }

        ~EnableScope() { release(); }

        void release() {
            if (released) {
                return;
            }
            timer.m_enabled--;
            released = true;
        }
    };

    using TimeScope = TimeScopeRecursive;

public:
    Timer(const char* name, bool default_enabled);

    const char* name() { return m_name; }
    auto time_scope() { return TimeScope(*this); }
    auto time_scope_recursive() { return TimeScopeRecursive(*this); };
    auto enable_scope() { return EnableScope(*this); }
    void reset() {
        m_duration = clock_t::duration{0};
        m_count = 0;
        m_enabled = m_default_enabled ? 1 : 0;
    }

    clock_t::duration get() const { return m_duration; }
    uint64_t count() const { return m_count; }
};
} // namespace stats

struct Stats {
    static inline std::vector<stats::Timer*> sm_timers;

    // register your timers here
    // for example:
    //
    // static inline stats::Timer mytimer;
    //
    // then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code

    static void print() {
        std::vector<const char*> unused_timers;
        for (auto* timer : sm_timers) {
            if (timer->count() == 0) {
                unused_timers.push_back(timer->name());
            } else {
                printf("%s costs %ld ns, happens %ld times\n", timer->name(),
                       timer->get().count(), timer->count());
            }
        }
        if (!unused_timers.empty()) {
            printf("%zu timers unused\n", unused_timers.size());
        }
    }

    static void reset() {
        for (auto* timer : sm_timers) {
            timer->reset();
        }
    }
};

inline stats::Timer::Timer(const char* name, bool default_enabled)
        : m_name(name), m_default_enabled(default_enabled) {
    Stats::sm_timers.push_back(this);
}

#if MGE_ENABLE_STATS
#define MGE_TIMER_SCOPE(name) auto name = Stats::name.time_scope()
#define MGE_TIMER_SCOPE_RELEASE(name) name.release()
#define MGE_TIMER_SCOPE_ENABLE(name) auto name = Stats::name.enable_scope()
#else
#define MGE_TIMER_SCOPE(name) (void)0
#define MGE_TIMER_SCOPE_RELEASE(name) (void)0
#define MGE_TIMER_SCOPE_ENABLE(name) (void)0
#endif

} // namespace imperative
} // namespace mgb
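
// Usage of the header above, per its own comment block: declare a timer as a
// static member of Stats, then wrap hot regions with MGE_TIMER_SCOPE (which
// compiles to (void)0 unless MGE_ENABLE_STATS is 1). A self-contained sketch
// with a hypothetical local timer mirroring Timer::time_scope:

#include <chrono>
#include <cstdio>

struct DemoTimer {
    std::chrono::nanoseconds total{0};
    unsigned long long hits = 0;

    struct Scope {
        DemoTimer& timer;
        std::chrono::steady_clock::time_point start =
                std::chrono::steady_clock::now();
        ~Scope() {
            timer.total += std::chrono::steady_clock::now() - start;
            timer.hits++;
        }
    };

    Scope time_scope() { return Scope{*this}; }
};

inline DemoTimer g_apply_timer;  // cf. registration into Stats::sm_timers

void hot_path() {
    auto _ = g_apply_timer.time_scope();  // cf. MGE_TIMER_SCOPE(name)
    // ... work to be measured ...
}

int main() {
    for (int i = 0; i < 3; ++i) {
        hot_path();
    }
    // cf. Stats::print(), exposed to Python above as print_stats
    std::printf("%lld ns over %llu calls\n",
                static_cast<long long>(g_apply_timer.total.count()),
                g_apply_timer.hits);
}
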
@@ -23,6 +23,7 @@
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {
@@ -58,6 +59,11 @@ public:
    inline size_t code() const { return m_code; }
};

enum class ValueKind {
    Primitive,
    Object,
};

/**
 * \brief an smart reference of value
 *
@@ -129,10 +135,10 @@ public:
     * \return TypedValueRef<TValue> reference if success, otherwise empty reference
     */
    template <typename TValue>
    inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const;
    inline const TypedValueRef<TValue>& as_ref(Type<TValue> type = {}) const;

    template <typename TValue>
    inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const;
    inline const TypedValueRef<TValue>& cast_ref(Type<TValue> type = {}) const;

    template <typename TValue>
    void on_cast_failure() const;
@@ -161,14 +167,18 @@ public:
    static bool any_watching();

    static const ValueRef nil;

    friend class ValueWeakRef;
    template <typename T>
    template <typename>
    friend class TypedValueRef;
    template <typename T>
    template <typename, ValueKind>
    friend class ValueImpl;
    friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
};

inline const ValueRef ValueRef::nil;

template <>
struct ToStringTrait<ValueRef> {
public:
@@ -241,7 +251,7 @@ public:

    friend class ValueRef;
    friend class ValueWeakRef;
    template <typename T>
    template <typename, ValueKind>
    friend class ValueImpl;
    template <typename T>
    friend class TypedValueRef;
@@ -257,7 +267,7 @@ private:
 *
 * \tparam T type of value
 */
template <typename T>
template <typename T, ValueKind Kind>
class ValueImpl : public Value {
protected:
    ValueImpl() : Value(TYPE_CODE) {}
@@ -267,6 +277,7 @@ public:
    using weak_ref_t = TypedValueWeakRef<T>;

    static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
    static constexpr ValueKind KIND = Kind;

    /**
     * \brief helper function for construct a value
@@ -288,8 +299,8 @@ public:
 * \tparam T type of value
 * \tparam TMixin type of mixin class
 */
template <typename T, typename TMixin>
class MixinValueImpl : public ValueImpl<T>, public TMixin {
template <typename T, ValueKind Kind, typename TMixin>
class MixinValueImpl : public ValueImpl<T, Kind>, public TMixin {
public:
    using TMixin::TMixin;
@@ -309,12 +320,14 @@ inline ValueRef::ValueRef(storage_t storage) {

template <typename TValue>
inline const TValue* ValueRef::as(Type<TValue> type) const {
    static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>);
    // auto _ = Stats::time_value_as.time_scope();
    static_assert(std::is_base_of_v<Value, TValue>);
    return static_cast<const TValue*>(as(type.code()));
}

template <typename TValue>
inline const TValue& ValueRef::cast(Type<TValue> type) const {
    // auto _ = Stats::time_value_cast.time_scope();
    auto* ptr = as<TValue>(type);
    if (mgb_unlikely(!ptr)) {
        on_cast_failure<TValue>();
@@ -324,26 +337,27 @@ inline const TValue& ValueRef::cast(Type<TValue> type) const {

template <typename TValue>
inline bool ValueRef::is(Type<TValue> type) const {
    // auto _ = Stats::time_value_is.time_scope();
    return is(type.code());
}

template <typename TValue>
inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const {
inline const TypedValueRef<TValue>& ValueRef::as_ref(Type<TValue> type) const {
    if (!is<TValue>(type)) {
        return {};
        return TypedValueRef<TValue>::nil;
    }
    return TypedValueRef<TValue>(*this);
    return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
}

template <typename TValue>
inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const {
inline const TypedValueRef<TValue>& ValueRef::cast_ref(Type<TValue> type) const {
    if (!m_storage) {
        return {};
        return TypedValueRef<TValue>::nil;
    }
    if (mgb_unlikely(!is<TValue>(type))) {
        on_cast_failure<TValue>();
    }
    return TypedValueRef<TValue>(*this);
    return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
}

template <typename TValue>
@@ -363,12 +377,31 @@ void ValueRef::on_cast_failure() const {

template <typename T>
class TypedValueRef : public ValueRef {
private:
    TypedValueRef(ValueRef value) : ValueRef(value) {}
    TypedValueRef(ValueRef value) : ValueRef(std::move(value)) {}

public:
    TypedValueRef() = default;
    const T& operator*() const { return this->template cast<T>(); }
    const T* operator->() const { return this->template as<T>(); }
    const T& operator*() const {
        if constexpr (T::KIND == ValueKind::Object) {
            return this->template cast<T>();
        } else if constexpr (T::KIND == ValueKind::Primitive) {
            if (!m_storage) {
                on_cast_failure<T>();
            }
            return static_cast<const T&>(*m_storage);
        } else {
            static_assert(!std::is_same_v<T, T>);
        }
    }

    const T* operator->() const {
        if constexpr (T::KIND == ValueKind::Object) {
            return this->template as<T>();
        } else if constexpr (T::KIND == ValueKind::Primitive) {
            return static_cast<const T*>(m_storage.get());
        } else {
            static_assert(!std::is_same_v<T, T>);
        }
    }

    /**
     * \brief reset underlying value to another value
@@ -376,6 +409,7 @@ public:
     * \param successor new value
     */
    inline void reset(ValueRef successor) {
        static_assert(T::KIND == ValueKind::Object);
        mgb_assert(m_storage);
        mgb_assert(!m_storage->m_successor);
        if (m_storage->m_watching) {
@@ -385,9 +419,11 @@ public:
        m_storage->m_successor = ValueRef(successor.storage());
    }

    static inline const TypedValueRef nil;

    friend class ValueRef;
    template <typename U>
    template <typename, ValueKind>
    friend class ValueImpl;
};
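
// TypedValueRef::operator* above branches at compile time on T::KIND: Object
// values go through the checked, successor-aware cast<T>(), while Primitive
// values are read straight out of storage. A standalone sketch of dispatching
// on a static kind tag with if constexpr, using hypothetical types:

#include <cassert>
#include <string>
#include <type_traits>

enum class Kind { Primitive, Object };

template <Kind K>
struct Boxed {
    static constexpr Kind KIND = K;
    std::string payload;
};

template <typename T>
const std::string& deref(const T& box) {
    if constexpr (T::KIND == Kind::Primitive) {
        return box.payload;  // fast path: no dynamic checks
    } else if constexpr (T::KIND == Kind::Object) {
        assert(!box.payload.empty());  // stand-in for the checked cast
        return box.payload;
    } else {
        static_assert(!std::is_same_v<T, T>, "unknown kind");
    }
}

int main() {
    Boxed<Kind::Primitive> p{"prim"};
    Boxed<Kind::Object> o{"obj"};
    assert(deref(p) == "prim");
    assert(deref(o) == "obj");
}
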
@@ -423,7 +459,7 @@ public:
    ValueRefList() = default;
    ValueRefList(size_t nr_elems);
    ValueRefList(ValueRef item);
    ValueRefList(std::initializer_list<ValueRef> values);
    // ValueRefList(std::initializer_list<ValueRef> values);
    template <typename TIterator>
    ValueRefList(TIterator begin, TIterator end);
    ValueRefList(const ValueRefList& rhs);