GitOrigin-RevId: 355da6b814
tags/v1.8.0
@@ -0,0 +1,98 @@ | |||
#include "megbrain/imperative/basic_operators.h" | |||
#include "megbrain/imperative/basic_values.h" | |||
namespace mgb { | |||
namespace imperative { | |||
std::string ApplyOp::to_string() const { | |||
return m_op.to_string(); | |||
} | |||
std::string GetAttr::to_string() const { | |||
std::string buffer; | |||
const char* attr_name = ([&] { | |||
switch (m_attr) { | |||
case None: | |||
return "None"; | |||
case DType: | |||
return "DType"; | |||
case Device: | |||
return "Device"; | |||
case Shape: | |||
return "Shape"; | |||
case Value: | |||
return "Value"; | |||
case Data: | |||
return "Data"; | |||
default: | |||
buffer = std::to_string(m_attr); | |||
return buffer.c_str(); | |||
} | |||
})(); | |||
return ssprintf("GetAttr{attr=%s}", attr_name); | |||
} | |||
CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape) | |||
: m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {} | |||
CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) | |||
: m_kind(kind), | |||
m_device(device), | |||
m_dtype(layout.dtype), | |||
m_shape(ValueShape::from(layout)) { | |||
mgb_assert( | |||
layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); | |||
} | |||
auto CreateTensor::parse(Span<ValueRef> inputs) -> Args { | |||
Args result; | |||
for (auto&& input : inputs) { | |||
if (auto host_storage = input.as_ref<HostStorage>()) { | |||
mgb_assert(!result.host, "duplicated host value"); | |||
result.host.emplace(); | |||
result.host->reset(*host_storage, {shape().as_tensor_shape(), dtype()}); | |||
mgb_assert(result.host->layout().ndim, "invalid shape"); | |||
} else if (auto device_storage = input.as_ref<DeviceStorage>()) { | |||
mgb_assert(!result.device, "duplicated device value"); | |||
result.device.emplace(device(), shape().as_tensor_shape(), dtype()); | |||
result.device->reset(*device_storage, {shape().as_tensor_shape(), dtype()}); | |||
mgb_assert(result.device->layout().ndim, "invalid shape"); | |||
} else { | |||
mgb_throw( | |||
MegBrainError, | |||
"unknown input type, expects HostStorage or DeviceStorage, got " | |||
"%s", | |||
input.name()->c_str()); | |||
} | |||
} | |||
mgb_assert( | |||
result.host || result.device, "require at least one of host/device value"); | |||
result.kind = kind(); | |||
return result; | |||
} | |||
std::string CreateTensor::to_string() const { | |||
return ssprintf( | |||
"CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind, | |||
m_device.to_string().c_str(), m_dtype.name(), m_shape.to_string().c_str()); | |||
} | |||
std::string DTRCommand::to_string() const { | |||
return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind); | |||
} | |||
std::string GetName::to_string() const { | |||
return "GetName{}"; | |||
} | |||
std::string RenameValue::to_string() const { | |||
return ssprintf("RenameValue{name=%s}", imperative::quoted(m_name).c_str()); | |||
} | |||
std::string IsScalar::to_string() const { | |||
return "IsScalar"; | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,81 @@ | |||
#include "megbrain/imperative/basic_values.h" | |||
namespace mgb { | |||
namespace imperative { | |||
std::string ShapeValue::to_string() const { | |||
return ssprintf("ValueShape%s", ValueShape::to_string().c_str()); | |||
} | |||
std::string CompNodeValue::to_string() const { | |||
return CompNode::to_string(); | |||
} | |||
std::string BoolValue::to_string() const { | |||
return (*m_value) ? "true" : "false"; | |||
} | |||
std::string HostStorage::to_string() const { | |||
return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str()); | |||
} | |||
std::string DeviceStorage::to_string() const { | |||
return ssprintf("DeviceStorage{device=%s}", comp_node().to_string().c_str()); | |||
} | |||
std::string HostValue::to_string() const { | |||
return ssprintf( | |||
"HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), | |||
m_dtype.name(), m_shape.to_string().c_str()); | |||
} | |||
HostTensorND HostValue::as_nd(bool allow_scalar) const { | |||
HostTensorND nd; | |||
TensorShape tensor_shape; | |||
if (m_shape.is_scalar()) { | |||
mgb_assert(allow_scalar); | |||
tensor_shape = TensorShape{1}; | |||
} else { | |||
tensor_shape = m_shape.as_tensor_shape(); | |||
} | |||
nd.reset(m_storage, {tensor_shape, dtype()}); | |||
return nd; | |||
} | |||
std::string DeviceValue::to_string() const { | |||
return ssprintf( | |||
"DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), | |||
m_dtype.name(), m_shape.to_string().c_str()); | |||
} | |||
DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const { | |||
DeviceTensorND nd; | |||
TensorShape tensor_shape; | |||
if (m_shape.is_scalar()) { | |||
mgb_assert(allow_scalar); | |||
tensor_shape = TensorShape{1}; | |||
} else { | |||
tensor_shape = m_shape.as_tensor_shape(); | |||
} | |||
nd.reset(m_storage, {tensor_shape, dtype()}); | |||
return nd; | |||
} | |||
std::string FunctionValue::to_string() const { | |||
return ssprintf("FunctionValue{type=%s}", target_type().name()); | |||
} | |||
std::string DTypeValue::to_string() const { | |||
return DType::name(); | |||
} | |||
std::string StringValue::to_string() const { | |||
return imperative::quoted((std::string&)*this); | |||
} | |||
std::string ErrorValue::to_string() const { | |||
return ssprintf("ErrorValue{message=%s}", message().c_str()); | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,108 @@ | |||
/** | |||
* \file imperative/src/impl/dispatch.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/dispatch.h" | |||
#include "megbrain/imperative/utils/debug.h" | |||
#include "megbrain/imperative/utils/helper.h" | |||
#include "megbrain/imperative/utils/map.h" | |||
namespace mgb { | |||
namespace imperative { | |||
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs) { | |||
static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH"); | |||
bool enable_watch = ValueRef::any_watching(); | |||
auto& context = Transformation::get_context(); | |||
size_t& depth = context.next_transformation; | |||
static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; | |||
const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1; | |||
bool log_current_dispatch = log_dispatch; | |||
if (enable_watch) { | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
auto& input = inputs[i]; | |||
if (input.watching()) { | |||
log_current_dispatch = true; | |||
mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str()); | |||
debug::notify_event("apply"); | |||
} | |||
} | |||
} | |||
// entrance | |||
std::vector<ValueRef> outputs; | |||
if (depth >= context.transformations.size()) { | |||
// fallback | |||
if (log_current_dispatch) { | |||
mgb_log_debug( | |||
"%sfallback apply %s in %s", tabs, op.to_string().c_str(), | |||
imperative::to_string(inputs).c_str()); | |||
} | |||
outputs = op.fallback(inputs); | |||
} else { | |||
// dispatch to stack top | |||
auto& transformation = *context.transformations[depth]; | |||
++depth; | |||
context.frames.push_back({op, inputs}); | |||
CleanupGuard _{[&] { | |||
context.frames.pop_back(); | |||
--depth; | |||
}}; | |||
if (log_current_dispatch) { | |||
mgb_log_debug( | |||
"%s%s apply %s in %s", tabs, transformation.name().c_str(), | |||
op.to_string().c_str(), imperative::to_string(inputs).c_str()); | |||
} | |||
outputs = transformation.apply_transformation(op, inputs); | |||
} | |||
if (log_current_dispatch) { | |||
mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str()); | |||
} | |||
return outputs; | |||
} | |||
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs) { | |||
return imperative::apply(ApplyOp{def}, inputs); | |||
} | |||
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) { | |||
SmallVector<ValueRef> inputs_storage; | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
inputs_storage.push_back(inputs[i]); | |||
} | |||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs, | |||
size_t) { | |||
auto outputs = imperative::apply(ApplyOp(*op), inputs); | |||
return SmallVector<ValueRef>(outputs.begin(), outputs.end()); | |||
}; | |||
auto make_const = [](TensorPtr constant) -> ValueRef { | |||
auto host_value = constant->get_value(); | |||
auto device_value = constant->dev_tensor(); | |||
mgb_assert( | |||
host_value.layout().is_contiguous() && | |||
device_value.layout().is_contiguous()); | |||
ValueShape shape; | |||
// FIXME: assume Tensor with shape {1} is scalar | |||
if (!constant->shape().is_scalar()) { | |||
shape = ValueShape::from(constant->shape()); | |||
} | |||
return imperative::apply( | |||
CreateTensor( | |||
CreateTensor::Const, constant->comp_node(), constant->dtype(), | |||
shape), | |||
HostStorage::make(host_value.storage()), | |||
DeviceStorage::make(device_value.storage()))[0]; | |||
}; | |||
auto outputs = graph.apply(inputs_storage, apply_functor, make_const); | |||
return {outputs.begin(), outputs.end()}; | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,22 @@ | |||
#include "megbrain/imperative/operator.h" | |||
namespace mgb { | |||
namespace imperative { | |||
std::vector<ValueRef> Operator::fallback(Span<ValueRef> inputs) const { | |||
mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str()); | |||
} | |||
size_t Operator::register_type(std::type_index type) { | |||
auto& types = const_cast<std::vector<std::type_index>&>(registered_types()); | |||
types.push_back(type); | |||
return types.size() - 1; | |||
} | |||
const std::vector<std::type_index>& Operator::registered_types() { | |||
static std::vector<std::type_index> sm_registered_types; | |||
return sm_registered_types; | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,12 @@ | |||
#include "megbrain/imperative/transformation.h" | |||
namespace mgb { | |||
namespace imperative { | |||
TransformationContext& Transformation::get_context() { | |||
thread_local TransformationContext tl_context; | |||
return tl_context; | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,34 @@ | |||
/** | |||
* \file imperative/src/impl/utils/debug.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <typeindex> | |||
#include "megbrain/imperative/transformation.h" | |||
#include "megbrain/imperative/utils/debug.h" | |||
#include "megbrain/imperative/value.h" | |||
namespace mgb::imperative::debug { | |||
const char* get_type_name(const std::type_info& type) { | |||
return type.name(); | |||
} | |||
const char* get_type_name(const std::type_index& type) { | |||
return type.name(); | |||
} | |||
void notify_event(const char* event) {} | |||
void watch_value(ValueRef value) { | |||
value.watch(); | |||
} | |||
} // namespace mgb::imperative::debug |
@@ -0,0 +1,190 @@ | |||
#include "megbrain/imperative/value.h" | |||
#include "megbrain/imperative/basic_operators.h" | |||
#include "megbrain/imperative/dispatch.h" | |||
#include "megbrain/imperative/utils/map.h" | |||
namespace mgb { | |||
namespace imperative { | |||
namespace { | |||
static thread_local size_t nr_watched_values = 0; | |||
static thread_local uint64_t nr_values = 0; | |||
static thread_local bool recording_values = false; | |||
static thread_local std::vector<ValueWeakRef> recorded_values; | |||
static WeakValueMap<uint64_t, ValueWeakRef> registered_values; | |||
} // namespace | |||
ValueRef::storage_t& ValueRef::storage() const { | |||
if (!m_storage) { | |||
return m_storage; | |||
} | |||
if (auto& storage = m_storage->m_successor.m_storage) { | |||
while (storage->m_successor.m_storage) { | |||
storage = storage->m_successor.m_storage; | |||
} | |||
return storage; | |||
} else { | |||
return m_storage; | |||
} | |||
} | |||
TypedValueRef<DeviceValue> ValueRef::dev_tensor() const { | |||
return imperative::apply(GetAttr(GetAttr::Data), *this)[0].as_ref<DeviceValue>(); | |||
} | |||
TypedValueRef<HostValue> ValueRef::numpy() const { | |||
return imperative::apply(GetAttr(GetAttr::Value), *this)[0].as_ref<HostValue>(); | |||
} | |||
TypedValueRef<CompNodeValue> ValueRef::device() const { | |||
return imperative::apply(GetAttr(GetAttr::Device), *this)[0] | |||
.as_ref<CompNodeValue>(); | |||
} | |||
TypedValueRef<ShapeValue> ValueRef::shape() const { | |||
return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].as_ref<ShapeValue>(); | |||
} | |||
TypedValueRef<DTypeValue> ValueRef::dtype() const { | |||
return imperative::apply(GetAttr(GetAttr::DType), *this)[0].as_ref<DTypeValue>(); | |||
} | |||
TypedValueRef<StringValue> ValueRef::name() const { | |||
return imperative::apply(GetName(), *this)[0].as_ref<StringValue>(); | |||
} | |||
bool ValueRef::is_scalar() const { | |||
return imperative::apply(IsScalar(), *this)[0].cast<BoolValue>(); | |||
} | |||
void ValueRef::watch() const { | |||
mgb_assert(m_storage); | |||
storage()->m_watching++; | |||
nr_watched_values++; | |||
storage()->on_watch(); | |||
// TODO: | |||
// imperative::apply(Watch(), this); | |||
} | |||
void ValueRef::unwatch() const { | |||
mgb_assert(m_storage); | |||
storage()->m_watching--; | |||
nr_watched_values--; | |||
storage()->on_unwatch(); | |||
} | |||
ValueRef ValueRef::unwrap() const { | |||
ValueRef value = *this; | |||
auto& context = Transformation::get_context(); | |||
for (size_t i = 0; i < context.next_transformation; ++i) { | |||
value = context.transformations[i]->unwrap(value); | |||
} | |||
mgb_assert(value); | |||
return value; | |||
} | |||
std::string ValueRef::to_string() const { | |||
if (!m_storage) { | |||
return "<empty value>"; | |||
} | |||
return ssprintf( | |||
"(%zu:%zu) %s", id(), storage()->m_id, storage()->to_string().c_str()); | |||
} | |||
std::string ValueRef::raw_type() const { | |||
if (!m_storage) { | |||
return "null"; | |||
} | |||
auto& types = Value::registered_types(); | |||
mgb_assert(types.size() > m_storage->m_typecode); | |||
return types[m_storage->m_typecode].name(); | |||
} | |||
uint64_t ValueRef::id() const { | |||
return m_storage ? m_storage->m_id : std::numeric_limits<uint64_t>::max(); | |||
} | |||
bool ValueRef::watching() const { | |||
auto storage = this->storage(); | |||
return storage && storage->m_watching; | |||
} | |||
ValueRef ValueRef::make(ValueRef::storage_t storage) { | |||
if (recording_values) { | |||
recorded_values.push_back({storage}); | |||
} | |||
return {storage}; | |||
} | |||
bool ValueRef::any_watching() { | |||
return nr_watched_values != 0; | |||
} | |||
ValueRef ValueWeakRef::lock() { | |||
auto strong_storage = m_storage.lock(); | |||
if ((!strong_storage) || strong_storage->m_successor) { | |||
return {}; | |||
} | |||
return {strong_storage}; | |||
} | |||
Value::Value(size_t typecode) : m_typecode{typecode} { | |||
m_id = nr_values++; | |||
} | |||
Value::~Value() { | |||
if (m_watching) { | |||
debug::notify_event("dtor"); | |||
} | |||
} | |||
size_t Value::register_type(std::type_index type) { | |||
auto& types = const_cast<std::vector<std::type_index>&>(registered_types()); | |||
types.push_back(type); | |||
return types.size() - 1; | |||
} | |||
const std::vector<std::type_index>& Value::registered_types() { | |||
static std::vector<std::type_index> sm_registered_types; | |||
return sm_registered_types; | |||
} | |||
void Value::register_value(ValueRef value) { | |||
registered_values[value.id()] = ValueWeakRef(value); | |||
} | |||
ValueRef Value::get_value_by_id(uint64_t id) { | |||
auto& weak_value = registered_values[id]; | |||
if (auto value = weak_value.lock()) { | |||
return value; | |||
} | |||
return {}; | |||
} | |||
void Value::begin_record_values() { | |||
mgb_assert(!recording_values); | |||
recording_values = true; | |||
recorded_values.clear(); | |||
} | |||
std::vector<ValueRef> Value::end_record_values() { | |||
recording_values = false; | |||
std::vector<ValueRef> recorded_strong_values; | |||
for (auto&& weak_value : recorded_values) { | |||
if (auto value = weak_value.lock()) { | |||
recorded_strong_values.push_back(value); | |||
} | |||
} | |||
return recorded_strong_values; | |||
} | |||
void Value::try_rethrow() { | |||
if (m_typecode == ErrorValue::TYPE_CODE) { | |||
auto message = static_cast<ErrorValue*>(this)->message(); | |||
mgb_throw(MegBrainError, "invalid value: %s", message.c_str()); | |||
} | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,176 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/basic_operators.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <future> | |||
#include <iomanip> | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/imperative/operator.h" | |||
#include "megbrain/imperative/utils/helper.h" | |||
#include "megbrain/imperative/utils/value_shape.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class GradKey; | |||
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
/** | |||
* \brief apply an OpDef to values | |||
* | |||
*/ | |||
class ApplyOp final : public OperatorImpl<ApplyOp> { | |||
private: | |||
const OpDef& m_op; | |||
public: | |||
ApplyOp(const OpDef& op) : m_op(op) {} | |||
const OpDef& op() { return m_op; } | |||
std::string to_string() const override; | |||
}; | |||
/** | |||
* \brief get an basic attribute from Value | |||
* | |||
*/ | |||
class GetAttr final : public OperatorImpl<GetAttr, Operator::GetAttrLike> { | |||
public: | |||
enum Attr { | |||
None, | |||
DType, | |||
Device, | |||
Shape, | |||
Value, | |||
Data, | |||
}; | |||
private: | |||
Attr m_attr = None; | |||
public: | |||
GetAttr(Attr attr) : m_attr(attr) { | |||
mgb_assert(attr != None, "invalid attr value: None"); | |||
} | |||
Attr attr() const { return m_attr; } | |||
std::string to_string() const; | |||
}; | |||
/** | |||
* \brief create a tensor value from host value or device value | |||
* | |||
*/ | |||
class CreateTensor final : public OperatorImpl<CreateTensor> { | |||
public: | |||
enum Kind { | |||
Common, // common mode, h2d can be cached to speed up | |||
Unique, // require output value to be unqiue (donnot share memory with other | |||
// values) | |||
Const, // put as constant (guaranteed to be same each time) | |||
NoTrace, // won't be trace in any case, would be used in make_backward_graph | |||
// (looking for a better name) | |||
}; | |||
struct Args { | |||
std::optional<HostTensorND> host; | |||
std::optional<DeviceTensorND> device; | |||
Kind kind; | |||
}; | |||
private: | |||
Kind m_kind; | |||
CompNode m_device; | |||
DType m_dtype; | |||
ValueShape m_shape; | |||
public: | |||
CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape); | |||
CreateTensor(Kind kind, CompNode device, TensorLayout layout); | |||
/** | |||
* \brief utility function to unpack args of CreateTensor | |||
* | |||
* \param inputs contains host_storage and device_storage | |||
* \return Args unpacked args | |||
*/ | |||
Args parse(Span<ValueRef> inputs); | |||
Kind kind() const { return m_kind; } | |||
CompNode device() const { return m_device; } | |||
DType dtype() const { return m_dtype; } | |||
ValueShape shape() const { return m_shape; } | |||
std::string to_string() const override; | |||
}; | |||
class DTRCommand final : public OperatorImpl<DTRCommand, Operator::GetAttrLike> { | |||
public: | |||
enum Kind { | |||
None, | |||
Drop, | |||
}; | |||
private: | |||
Kind m_kind = None; | |||
public: | |||
DTRCommand(Kind kind) : m_kind(kind) {} | |||
Kind kind() { return m_kind; } | |||
std::string to_string() const override; | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { return {}; } | |||
}; | |||
// deprecated | |||
class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> { | |||
public: | |||
std::string to_string() const override; | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
return {ValueRef()}; | |||
} | |||
}; | |||
/** | |||
* \brief return a value with new name | |||
* | |||
*/ | |||
class RenameValue : public OperatorImpl<RenameValue, Operator::IdentityLike> { | |||
private: | |||
std::string m_name; | |||
public: | |||
RenameValue(std::string name) : m_name(name) {} | |||
std::string name() const { return m_name; } | |||
std::string to_string() const override; | |||
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
return {inputs.as_array<1>()[0]}; | |||
} | |||
}; | |||
class IsScalar final : public OperatorImpl<IsScalar, Operator::GetAttrLike> { | |||
private: | |||
public: | |||
std::string to_string() const override; | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,178 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/basic_values.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <future> | |||
#include <iomanip> | |||
#include "megbrain/imperative/utils/helper.h" | |||
#include "megbrain/imperative/utils/value_shape.h" | |||
#include "megbrain/imperative/value.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class GradKey; | |||
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
class CompNodeValue final : public MixinValueImpl<CompNodeValue, CompNode> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
// TODO: override factory method | |||
class BoolValue final : public ValueImpl<BoolValue> { | |||
private: | |||
std::optional<bool> m_value; | |||
public: | |||
BoolValue(bool value) : m_value{value} {} | |||
operator bool() const { return *m_value; } | |||
std::string to_string() const override; | |||
void clear() override { m_value.reset(); } | |||
}; | |||
class HostStorage final : public MixinValueImpl<HostStorage, HostTensorStorage> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
class DeviceStorage final : public MixinValueImpl<DeviceStorage, DeviceTensorStorage> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
/** | |||
* \brief like HostTensorND mixin, but allow scalar value | |||
* | |||
*/ | |||
class HostValue final : public ValueImpl<HostValue> { | |||
private: | |||
DType m_dtype; | |||
ValueShape m_shape; | |||
HostTensorStorage m_storage; | |||
public: | |||
HostValue(DType dtype, ValueShape shape, HostTensorStorage storage) | |||
: m_dtype(dtype), m_shape(shape), m_storage(storage) {} | |||
HostValue(HostTensorND value) | |||
: HostValue( | |||
value.dtype(), ValueShape::from(value.shape()), value.storage()) { | |||
} | |||
std::string to_string() const override; | |||
void clear() override { | |||
m_dtype = {}; | |||
m_shape = {}; | |||
m_storage = {}; | |||
} | |||
DType dtype() const { return m_dtype; } | |||
ValueShape shape() const { return m_shape; } | |||
CompNode device() const { return m_storage.comp_node(); } | |||
HostTensorStorage storage() const { return m_storage; } | |||
HostTensorND as_nd(bool allow_scalar = false) const; | |||
}; | |||
/** | |||
* \brief like DeviceTensorND mixin, but allow scalar value | |||
* | |||
*/ | |||
class DeviceValue final : public ValueImpl<DeviceValue> { | |||
private: | |||
DType m_dtype; | |||
ValueShape m_shape; | |||
DeviceTensorStorage m_storage; | |||
public: | |||
DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage) | |||
: m_dtype(dtype), m_shape(shape), m_storage(storage) {} | |||
DeviceValue(DeviceTensorND value) | |||
: DeviceValue( | |||
value.dtype(), ValueShape::from(value.shape()), value.storage()) { | |||
} | |||
std::string to_string() const override; | |||
void clear() override { | |||
m_dtype = {}; | |||
m_shape = {}; | |||
m_storage = {}; | |||
} | |||
DType dtype() const { return m_dtype; } | |||
ValueShape shape() const { return m_shape; } | |||
CompNode device() const { return m_storage.comp_node(); } | |||
DeviceTensorStorage storage() const { return m_storage; } | |||
DeviceTensorND as_nd(bool allow_scalar = false) const; | |||
}; | |||
class FunctionValue final : public MixinValueImpl<FunctionValue, GenericFunction> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
class DTypeValue final : public MixinValueImpl<DTypeValue, DType> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
class StringValue final : public MixinValueImpl<StringValue, std::string> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
class Error { | |||
protected: | |||
std::string m_message; | |||
public: | |||
Error() = default; | |||
Error(std::string message) : m_message(message) {} | |||
std::string message() const { return m_message; } | |||
}; | |||
class ErrorValue final : public MixinValueImpl<ErrorValue, Error> { | |||
public: | |||
using MixinValueImpl::MixinValueImpl; | |||
std::string to_string() const override; | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,72 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/dispatch.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <list> | |||
#include <map> | |||
#include <memory> | |||
#include <typeinfo> | |||
#include <vector> | |||
#include "megbrain/common.h" | |||
#include "megbrain/imperative/basic_operators.h" | |||
#include "megbrain/imperative/basic_values.h" | |||
#include "megbrain/imperative/operator.h" | |||
#include "megbrain/imperative/subgraph.h" | |||
#include "megbrain/imperative/transformation.h" | |||
#include "megbrain/imperative/utils/local_ptr.h" | |||
#include "megbrain/imperative/utils/span.h" | |||
#include "megbrain/imperative/value.h" | |||
namespace mgb { | |||
namespace imperative { | |||
/** | |||
* \brief dispatch entrance, requests would be forwarded to current top transformation | |||
* (or fallback) | |||
* | |||
* \param op | |||
* \param inputs | |||
* \return std::vector<ValueRef> | |||
*/ | |||
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs); | |||
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs); | |||
template <typename... TArgs> | |||
constexpr bool is_all_value_ref_v = | |||
(... && (std::is_base_of_v<ValueRef, std::decay_t<TArgs>> || | |||
std::is_same_v<ValueRef, std::decay_t<TArgs>>)); | |||
template <typename T, typename... TArgs> | |||
static auto apply(T&& op, TArgs&&... args) | |||
-> std::enable_if_t<is_all_value_ref_v<TArgs...>, std::vector<ValueRef>> { | |||
ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...}; | |||
return imperative::apply( | |||
std::forward<T&&>(op), | |||
Span<ValueRef>(std::begin(args_arr), std::end(args_arr))); | |||
} | |||
template <typename T, typename TContainer> | |||
static auto apply(T&& op, TContainer&& container) -> std::enable_if_t< | |||
std::is_same_v< | |||
std::remove_const_t<std::remove_pointer_t<decltype(container.data())>>, | |||
ValueRef> && | |||
std::is_same_v<decltype(container.size()), size_t> && | |||
!std::is_same_v<std::decay_t<TContainer>, Span<ValueRef>>, | |||
std::vector<ValueRef>> { | |||
return imperative::apply( | |||
std::forward<T&&>(op), Span<ValueRef>(container.data(), container.size())); | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,102 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/operator.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <list> | |||
#include <map> | |||
#include <memory> | |||
#include <typeindex> | |||
#include <typeinfo> | |||
#include <vector> | |||
#include "megbrain/common.h" | |||
#include "megbrain/imperative/utils/span.h" | |||
#include "megbrain/imperative/value.h" | |||
namespace mgb { | |||
namespace imperative { | |||
/** | |||
* \brief base class for all operators | |||
* | |||
*/ | |||
class Operator { | |||
public: | |||
enum Kind { | |||
IdentityLike, // one input, one output, output is like input | |||
GetAttrLike, // no tensor output | |||
Other, | |||
}; | |||
private: | |||
size_t m_typecode; | |||
Kind m_kind; | |||
protected: | |||
Operator(size_t typecode, Kind kind) : m_typecode{typecode}, m_kind{kind} {} | |||
public: | |||
size_t typecode() const { return m_typecode; } | |||
Kind kind() const { return m_kind; } | |||
template <typename U> | |||
U* as() const { | |||
if (m_typecode != U::TYPE_CODE) { | |||
return nullptr; | |||
} | |||
return static_cast<U*>(const_cast<Operator*>(this)); | |||
} | |||
template <typename U> | |||
bool is() const { | |||
return as<U>() != nullptr; | |||
} | |||
template <Kind kKind> | |||
bool is() const { | |||
return kind() == kKind; | |||
} | |||
template <typename U> | |||
U& cast() const { | |||
U* ptr = as<U>(); | |||
mgb_assert(ptr); | |||
return *ptr; | |||
} | |||
virtual std::string to_string() const = 0; | |||
/** | |||
* \brief fallback implementation of this. Not all operators has fallback | |||
* implementation. | |||
* | |||
* \param inputs | |||
* \return std::vector<ValueRef> | |||
*/ | |||
virtual std::vector<ValueRef> fallback(Span<ValueRef> inputs) const; | |||
std::type_index type() const { return registered_types()[m_typecode]; } | |||
static size_t register_type(std::type_index type); | |||
static const std::vector<std::type_index>& registered_types(); | |||
}; | |||
template <typename T, Operator::Kind kKind = Operator::Other> | |||
class OperatorImpl : public Operator { | |||
protected: | |||
OperatorImpl() : Operator(TYPE_CODE, kKind) {} | |||
public: | |||
static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); | |||
std::string to_string() const override = 0; | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,199 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/transformation.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <list> | |||
#include <map> | |||
#include <memory> | |||
#include <vector> | |||
#include "megbrain/common.h" | |||
#include "megbrain/imperative/subgraph.h" | |||
#include "megbrain/imperative/utils/local_ptr.h" | |||
#include "megbrain/imperative/utils/span.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class ValueRef; | |||
class Operator; | |||
class Transformation; | |||
/** | |||
* \brief args of dispatch action | |||
* | |||
*/ | |||
struct TransformationFrame { | |||
const Operator& op; | |||
const Span<ValueRef>& inputs; | |||
}; | |||
struct TransformationContext { | |||
std::vector<std::shared_ptr<Transformation>> transformations; | |||
std::vector<std::string> scopes; | |||
// TODO: deprecate TransformationGuard, let next_transformation == frames.size() | |||
size_t next_transformation = 0; | |||
std::vector<TransformationFrame> frames; | |||
}; | |||
/** | |||
* \brief Transformation handles operation requests. | |||
* | |||
* There is an transformation stack in each context. When user send an operation | |||
* request, it is firstly passed to the top transformation. When a transformation in the | |||
* stack receiving a request, it should handle it and give a response. Transformations | |||
* are allowed to send requests when handling other requests, those requests would be | |||
* sent to downstairs. A transformation can only be added to one stack. | |||
*/ | |||
class Transformation : public std::enable_shared_from_this<Transformation> { | |||
public: | |||
using pos_t = | |||
decltype(std::declval<TransformationContext>().transformations)::iterator; | |||
class TransformationGuard { | |||
private: | |||
size_t m_priority; | |||
public: | |||
TransformationGuard(size_t priority) : m_priority{priority} { | |||
auto& context = get_context(); | |||
std::swap(m_priority, context.next_transformation); | |||
mgb_assert( | |||
context.next_transformation <= context.transformations.size(), | |||
"invalid priority: %zu vs %zu", context.next_transformation, | |||
context.transformations.size()); | |||
} | |||
~TransformationGuard() { | |||
std::swap(m_priority, get_context().next_transformation); | |||
} | |||
}; | |||
private: | |||
size_t m_priority = std::numeric_limits<size_t>::max(); | |||
public: | |||
/** | |||
* \brief handle a dispatch request | |||
* | |||
* \param op | |||
* \param inputs | |||
* \return std::vector<ValueRef> | |||
*/ | |||
virtual std::vector<ValueRef> apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) = 0; | |||
virtual ValueRef unwrap(ValueRef value) = 0; | |||
virtual std::string name() const = 0; | |||
/** | |||
* \brief called when added to a stack. | |||
*/ | |||
virtual void on_register(){}; | |||
/** | |||
* \brief called when remove from a stack. | |||
* | |||
* Some transformations, like GradTransformation and TraceTransformation, produce | |||
* special values when handling requests. Thus they should recover these values on | |||
* unregistering because other transformations cann't recognize them. | |||
*/ | |||
virtual void on_unregister() noexcept {}; | |||
public: | |||
static auto top() { return get_context().transformations.begin(); } | |||
static auto bottom() { return get_context().transformations.end(); } | |||
static void push_scope(std::string scope) { get_context().scopes.push_back(scope); } | |||
static void pop_scope(std::string scope) { | |||
auto& context = get_context(); | |||
auto top = context.scopes.back(); | |||
context.scopes.pop_back(); | |||
mgb_assert(top == scope); | |||
} | |||
static std::vector<std::string> scopes() { return get_context().scopes; } | |||
/** | |||
* \brief position at transformation stack | |||
* | |||
* \return auto position | |||
*/ | |||
auto pos() const { | |||
mgb_assert( | |||
m_priority != std::numeric_limits<size_t>::max(), "not yet registered"); | |||
return top() + m_priority; | |||
} | |||
/** | |||
* \brief register this at given position | |||
* | |||
* \param pos position | |||
*/ | |||
void register_at(pos_t pos) { | |||
auto& context = get_context(); | |||
mgb_assert( | |||
m_priority == std::numeric_limits<size_t>::max(), "already registered"); | |||
size_t priority = pos - context.transformations.begin(); | |||
for (auto iter = pos; iter != context.transformations.end(); ++iter) { | |||
iter->get()->m_priority++; | |||
} | |||
m_priority = priority; | |||
context.transformations.insert(pos, shared_from_this()); | |||
{ | |||
TransformationGuard _{m_priority + 1}; | |||
on_register(); | |||
} | |||
// assert priority | |||
} | |||
/** | |||
* \brief unregister this from transformation stack | |||
*/ | |||
void unregister() noexcept { | |||
auto& context = get_context(); | |||
mgb_assert( | |||
m_priority != std::numeric_limits<size_t>::max(), "not yet registered"); | |||
{ | |||
TransformationGuard _{m_priority + 1}; | |||
on_unregister(); | |||
} | |||
size_t priority = m_priority; | |||
auto pos = top() + priority; | |||
for (auto iter = pos; iter != context.transformations.end(); ++iter) { | |||
iter->get()->m_priority--; | |||
} | |||
m_priority = std::numeric_limits<size_t>::max(); | |||
context.transformations.erase(pos); | |||
// TODO: assert priority | |||
} | |||
// FIXME: deprecated | |||
[[nodiscard]] TransformationGuard current_level_guard() { return m_priority; } | |||
/** | |||
* \brief swap current context with target | |||
* | |||
* \param context target context | |||
*/ | |||
static void swap_context(TransformationContext& context) { | |||
auto& current_context = get_context(); | |||
std::swap(context.transformations, current_context.transformations); | |||
std::swap(context.scopes, current_context.scopes); | |||
std::swap(context.next_transformation, current_context.next_transformation); | |||
} | |||
static TransformationContext& get_context(); | |||
friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
friend class ValueRef; | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,20 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/debug.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <memory> | |||
namespace mgb::imperative::debug { | |||
void notify_event(const char* event); | |||
} |
@@ -0,0 +1,388 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/value.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <list> | |||
#include <map> | |||
#include <memory> | |||
#include <typeindex> | |||
#include <vector> | |||
#include "megbrain/common.h" | |||
#include "megbrain/imperative/subgraph.h" | |||
#include "megbrain/imperative/utils/allocator.h" | |||
#include "megbrain/imperative/utils/debug.h" | |||
#include "megbrain/imperative/utils/local_ptr.h" | |||
#include "megbrain/imperative/utils/span.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class Value; | |||
class ValueRef; | |||
template <typename T> | |||
class TypedValueRef; | |||
template <typename T> | |||
class TypedValueWeakRef; | |||
class Transformation; | |||
class HostValue; | |||
class DeviceValue; | |||
class ShapeValue; | |||
class DTypeValue; | |||
class CompNodeValue; | |||
class StringValue; | |||
class Operator; | |||
/** | |||
* \brief an smart reference of value | |||
* | |||
* An ValueRef is either empty or refers to a value. Values are organized as linked lists | |||
* and only the tail node is valid. ValueRef stores a value node, and it may be | |||
* an invalid internal node. When you dereference it, it will check its successor, | |||
* automatically find the tail node and return. This list would be modified to reduce | |||
* list length by change value's successor, but a ValueRef always has steady m_storage | |||
* when not explicitly modified. | |||
* So we use m_storage to identify a ValueRef ( hash / equility / id ). | |||
*/ | |||
class ValueRef { | |||
public: | |||
using storage_t = LocalPtr<Value>; | |||
protected: | |||
mutable storage_t m_storage; | |||
ValueRef(storage_t storage) { m_storage = storage; } | |||
private: | |||
/** | |||
* \brief recursive get dest value storage and shorten path | |||
* | |||
* \return storage_t dest storage | |||
*/ | |||
storage_t& storage() const; | |||
public: | |||
ValueRef() = default; | |||
/** | |||
* \brief whether value is instance of target type or not | |||
* | |||
* \tparam TValue target type | |||
* \return true if type of value is TValue | |||
* \return false if empty or type of value is not TValue | |||
*/ | |||
template <typename TValue> | |||
bool is() const; | |||
/** | |||
* \brief try cast value as target type | |||
* | |||
* \tparam TValue target type | |||
* \return TValue* raw pointer if success, otherwise nullptr | |||
*/ | |||
template <typename TValue> | |||
const TValue* as() const; | |||
/** | |||
* \brief cast value to target type | |||
* | |||
* \tparam TValue target type | |||
* \return TValue& reference of value | |||
*/ | |||
template <typename TValue> | |||
const TValue& cast() const; | |||
/** | |||
* \brief like as(), but returns TypedValueRef instead | |||
* | |||
* \tparam TValue target type | |||
* \return TypedValueRef<TValue> reference if success, otherwise empty reference | |||
*/ | |||
template <typename TValue> | |||
inline TypedValueRef<TValue> as_ref() const; | |||
operator bool() const { return bool(m_storage); } | |||
TypedValueRef<DeviceValue> dev_tensor() const; | |||
TypedValueRef<HostValue> numpy() const; | |||
TypedValueRef<CompNodeValue> device() const; | |||
TypedValueRef<ShapeValue> shape() const; | |||
TypedValueRef<DTypeValue> dtype() const; | |||
TypedValueRef<StringValue> name() const; | |||
bool is_scalar() const; | |||
void watch() const; | |||
void unwatch() const; | |||
bool watching() const; | |||
ValueRef unwrap() const; | |||
std::string to_string() const; | |||
std::string raw_type() const; | |||
uint64_t id() const; | |||
size_t hash() const { return id(); } | |||
static ValueRef make(storage_t storage); | |||
static bool any_watching(); | |||
friend class ValueWeakRef; | |||
template <typename T> | |||
friend class TypedValueRef; | |||
template <typename T> | |||
friend class ValueImpl; | |||
friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
}; | |||
template <> | |||
struct ToStringTrait<ValueRef> { | |||
public: | |||
std::string operator()(const ValueRef& value) const { return value.to_string(); } | |||
}; | |||
class ValueWeakRef { | |||
public: | |||
using storage_t = ValueRef::storage_t::weak_type; | |||
protected: | |||
uint64_t m_id = std::numeric_limits<uint64_t>::max(); | |||
mutable storage_t m_storage; | |||
public: | |||
ValueWeakRef() = default; | |||
ValueWeakRef(ValueRef value) : m_id(value.id()), m_storage(value.m_storage) {} | |||
/** | |||
* \brief try promote to ValueRef | |||
* | |||
* \return ValueRef strong ref if value exists, otherwise empty ref | |||
*/ | |||
ValueRef lock(); | |||
size_t hash() const { return m_id; } | |||
bool operator==(const ValueWeakRef& rhs) const { | |||
return m_storage == rhs.m_storage; | |||
} | |||
bool operator!=(const ValueWeakRef& rhs) const { return !(*this == rhs); } | |||
}; | |||
/** | |||
* \brief base class for all generic value involved in dispatch system | |||
* | |||
*/ | |||
class Value : public NonCopyableObj { | |||
private: | |||
uint64_t m_id = std::numeric_limits<uint64_t>::max(); | |||
size_t m_typecode = 0; | |||
ValueRef m_successor; | |||
size_t m_watching = 0; | |||
protected: | |||
Value(size_t typecode); | |||
public: | |||
size_t typecode() const { return m_typecode; } | |||
const std::type_index type() const { return registered_types()[m_typecode]; } | |||
static size_t register_type(std::type_index type); | |||
static const std::vector<std::type_index>& registered_types(); | |||
static void register_value(ValueRef value); | |||
static ValueRef get_value_by_id(uint64_t id); | |||
static void begin_record_values(); | |||
static std::vector<ValueRef> end_record_values(); | |||
virtual std::string to_string() const = 0; | |||
/** | |||
* \brief clear all states of this value | |||
* | |||
*/ | |||
virtual void clear() = 0; | |||
virtual void on_watch() {} | |||
virtual void on_unwatch() {} | |||
friend class ValueRef; | |||
friend class ValueWeakRef; | |||
template <typename T> | |||
friend class ValueImpl; | |||
template <typename T> | |||
friend class TypedValueRef; | |||
~Value(); | |||
private: | |||
void try_rethrow(); | |||
}; | |||
/** | |||
* \brief base class of values, with typecode and factory method support | |||
* | |||
* \tparam T type of value | |||
*/ | |||
template <typename T> | |||
class ValueImpl : public Value { | |||
protected: | |||
ValueImpl() : Value(TYPE_CODE) {} | |||
public: | |||
using ref_t = TypedValueRef<T>; | |||
using weak_ref_t = TypedValueWeakRef<T>; | |||
static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); | |||
/** | |||
* \brief helper function for construct a value | |||
* | |||
* \tparam TArgs types of arguments | |||
* \param args arguments | |||
* \return TypedValueRef<T> reference of value | |||
*/ | |||
template <typename... TArgs> | |||
static TypedValueRef<T> make(TArgs&&... args) { | |||
static_assert(std::is_final_v<T>); | |||
return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...)); | |||
} | |||
}; | |||
/** | |||
* \brief base class of values, with mixin support | |||
* | |||
* \tparam T type of value | |||
* \tparam TMixin type of mixin class | |||
*/ | |||
template <typename T, typename TMixin> | |||
class MixinValueImpl : public ValueImpl<T>, public TMixin { | |||
public: | |||
using TMixin::TMixin; | |||
MixinValueImpl(TMixin mixin) : TMixin(std::move(mixin)) {} | |||
public: | |||
void clear() override final { ((TMixin&)*this) = {}; } | |||
bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; } | |||
}; | |||
template <typename TValue> | |||
const TValue* ValueRef::as() const { | |||
static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>); | |||
auto storage = this->storage(); | |||
if (!storage) { | |||
return nullptr; | |||
} | |||
if (storage->m_typecode != TValue::TYPE_CODE) { | |||
return nullptr; | |||
} | |||
return static_cast<TValue*>(storage.get()); | |||
} | |||
template <typename TValue> | |||
const TValue& ValueRef::cast() const { | |||
auto* ptr = as<TValue>(); | |||
if (!ptr) { | |||
// if this is ErrorValue, rethrow directly | |||
storage()->try_rethrow(); | |||
mgb_assert( | |||
ptr, "expect type %s, got %s", typeid(TValue).name(), | |||
to_string().c_str()); | |||
} | |||
return *ptr; | |||
} | |||
template <typename TValue> | |||
bool ValueRef::is() const { | |||
auto* ptr = as<TValue>(); | |||
return ptr != nullptr; | |||
} | |||
template <typename TValue> | |||
TypedValueRef<TValue> ValueRef::as_ref() const { | |||
if (!is<TValue>()) { | |||
return {}; | |||
} | |||
return TypedValueRef<TValue>(*this); | |||
} | |||
/** | |||
* \brief ValueRef with concrete type, convenient for dereference | |||
* | |||
* \tparam T type of value | |||
*/ | |||
template <typename T> | |||
class TypedValueRef : public ValueRef { | |||
private: | |||
TypedValueRef(ValueRef value) : ValueRef(value) {} | |||
public: | |||
TypedValueRef() = default; | |||
const T& operator*() const { return this->template cast<T>(); } | |||
const T* operator->() const { return this->template as<T>(); } | |||
/** | |||
* \brief reset underlying value to another value | |||
* | |||
* \param successor new value | |||
*/ | |||
inline void reset(ValueRef successor) { | |||
mgb_assert(m_storage); | |||
mgb_assert(!m_storage->m_successor); | |||
if (m_storage->m_watching) { | |||
debug::notify_event("reset"); | |||
} | |||
m_storage->clear(); | |||
m_storage->m_successor = ValueRef(successor.storage()); | |||
} | |||
friend class ValueRef; | |||
template <typename U> | |||
friend class ValueImpl; | |||
}; | |||
template <typename T> | |||
class TypedValueWeakRef : public ValueWeakRef { | |||
private: | |||
public: | |||
TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {} | |||
TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {} | |||
TypedValueRef<T> lock() { return ValueWeakRef::lock().template as_ref<T>(); } | |||
}; | |||
// TODO: add proxy value type, which is meant to be reset in the end | |||
} // namespace imperative | |||
} // namespace mgb | |||
namespace std { | |||
template <> | |||
struct hash<mgb::imperative::ValueWeakRef> { | |||
std::size_t operator()(const mgb::imperative::ValueWeakRef& weak_ref) const { | |||
return weak_ref.hash(); | |||
} | |||
}; | |||
template <> | |||
struct hash<mgb::imperative::ValueRef> { | |||
std::size_t operator()(const mgb::imperative::ValueRef& ref) const { | |||
return ref.hash(); | |||
} | |||
}; | |||
} // namespace std |