|
@@ -0,0 +1,131 @@ |
|
|
|
|
|
/** |
|
|
|
|
|
* \file imperative/src/include/megbrain/imperative/symbol.h |
|
|
|
|
|
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") |
|
|
|
|
|
* |
|
|
|
|
|
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. |
|
|
|
|
|
* |
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, |
|
|
|
|
|
* software distributed under the License is distributed on an |
|
|
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
|
*/ |
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
|
|
|
|
|
|
|
|
#include <future> |
|
|
|
|
|
#include <variant> |
|
|
|
|
|
|
|
|
|
|
|
#include "megbrain/imperative/dispatch.h" |
|
|
|
|
|
#include "megbrain/imperative/interpreter.h" |
|
|
|
|
|
#include "megbrain/imperative/opr_utility.h" |
|
|
|
|
|
#include "megbrain/imperative/utils/helper.h" |
|
|
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
|
|
|
|
|
|
|
|
namespace mgb::imperative { |
|
|
|
|
|
|
|
|
|
|
|
class SymbolValue final : public ValueImpl<SymbolValue> { |
|
|
|
|
|
private: |
|
|
|
|
|
VarNode* m_node = nullptr; |
|
|
|
|
|
|
|
|
|
|
|
public: |
|
|
|
|
|
SymbolValue(VarNode* node) : m_node(node) {} |
|
|
|
|
|
|
|
|
|
|
|
VarNode* node() const { return m_node; } |
|
|
|
|
|
|
|
|
|
|
|
std::string to_string() const override { return ssprintf("VarNode{%p}", m_node); } |
|
|
|
|
|
|
|
|
|
|
|
void clear() override { m_node = nullptr; } |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
|
* \brief this transformation is used to handle VarNode. |
|
|
|
|
|
* |
|
|
|
|
|
* Unlike other transformations, this transformation is not used in Tensor evaluation. |
|
|
|
|
|
* when user calls py_apply(SymbolVar), we'll switch current transformation context to a |
|
|
|
|
|
* special symbol context. The advantage is that we can handle scalar by |
|
|
|
|
|
* ScalarTransformation. |
|
|
|
|
|
*/ |
|
|
|
|
|
class SymbolTransformation final : public Transformation { |
|
|
|
|
|
private: |
|
|
|
|
|
ComputingGraph* m_graph = nullptr; |
|
|
|
|
|
|
|
|
|
|
|
public: |
|
|
|
|
|
SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} |
|
|
|
|
|
std::vector<ValueRef> apply_transformation( |
|
|
|
|
|
const Operator& op, Span<ValueRef> inputs) override { |
|
|
|
|
|
if (auto* apply_op = op.as<ApplyOp>()) { |
|
|
|
|
|
SmallVector<VarNode*> input_nodes; |
|
|
|
|
|
for (auto&& input : inputs) { |
|
|
|
|
|
input_nodes.push_back(input.cast<SymbolValue>().node()); |
|
|
|
|
|
} |
|
|
|
|
|
auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); |
|
|
|
|
|
std::vector<ValueRef> outputs; |
|
|
|
|
|
for (auto&& output_node : output_nodes) { |
|
|
|
|
|
outputs.push_back(SymbolValue::make(output_node)); |
|
|
|
|
|
} |
|
|
|
|
|
return outputs; |
|
|
|
|
|
} else if (auto* create_tensor = op.as<CreateTensor>()) { |
|
|
|
|
|
auto&& args = create_tensor->parse(inputs); |
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
args.kind == CreateTensor::Const, |
|
|
|
|
|
"only const value is allowed here"); |
|
|
|
|
|
auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); |
|
|
|
|
|
return {SymbolValue::make(node)}; |
|
|
|
|
|
} else if (auto* get_attr = op.as<GetAttr>()) { |
|
|
|
|
|
auto* node = inputs.as_array<1>()[0].cast<SymbolValue>().node(); |
|
|
|
|
|
switch (get_attr->attr()) { |
|
|
|
|
|
case GetAttr::DType: |
|
|
|
|
|
return {DTypeValue::make(node->dtype())}; |
|
|
|
|
|
case GetAttr::Device: |
|
|
|
|
|
return {CompNodeValue::make(node->comp_node())}; |
|
|
|
|
|
case GetAttr::Shape: { |
|
|
|
|
|
if (!cg::is_static_var_shape(node)) { |
|
|
|
|
|
mgb_log_debug( |
|
|
|
|
|
"shape inference invalid for %s", node->name().c_str()); |
|
|
|
|
|
return {ValueRef()}; |
|
|
|
|
|
} |
|
|
|
|
|
auto shape = m_graph->static_infer_manager().infer_shape(node); |
|
|
|
|
|
return {ShapeValue::make(ValueShape::from(shape))}; |
|
|
|
|
|
} |
|
|
|
|
|
case GetAttr::Value: { |
|
|
|
|
|
if (!cg::is_static_var_value(node)) { |
|
|
|
|
|
mgb_log_debug( |
|
|
|
|
|
"value inference invalid for %s", node->name().c_str()); |
|
|
|
|
|
return {ValueRef()}; |
|
|
|
|
|
} |
|
|
|
|
|
auto inferred_value = |
|
|
|
|
|
m_graph->static_infer_manager().infer_value(node); |
|
|
|
|
|
HostTensorND host_value(node->comp_node(), node->dtype()); |
|
|
|
|
|
host_value.copy_from(inferred_value); |
|
|
|
|
|
return {HostValue::make(host_value)}; |
|
|
|
|
|
} |
|
|
|
|
|
case GetAttr::Data: { |
|
|
|
|
|
if (!cg::is_static_var_value(node)) { |
|
|
|
|
|
mgb_log_debug( |
|
|
|
|
|
"value inference invalid for %s", node->name().c_str()); |
|
|
|
|
|
return {ValueRef()}; |
|
|
|
|
|
} |
|
|
|
|
|
auto inferred_value = |
|
|
|
|
|
m_graph->static_infer_manager().infer_value(node); |
|
|
|
|
|
DeviceTensorND dev_value(node->comp_node(), node->dtype()); |
|
|
|
|
|
dev_value.copy_from(inferred_value); |
|
|
|
|
|
return {DeviceValue::make(dev_value)}; |
|
|
|
|
|
} |
|
|
|
|
|
default: |
|
|
|
|
|
mgb_throw( |
|
|
|
|
|
MegBrainError, "Symbol: malformed GetAttr: %s", |
|
|
|
|
|
op.to_string().c_str()); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
return op.fallback(inputs); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ValueRef unwrap(ValueRef value) override { |
|
|
|
|
|
mgb_assert(!value.is<SymbolValue>(), "SymbolValue doesn't support unwrap"); |
|
|
|
|
|
return value; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::string name() const override { return "SymbolTransformation"; } |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
} // namespace mgb::imperative |