@@ -309,7 +309,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||||
for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
inputs.push_back(args[i]->shared_from_this()); | inputs.push_back(args[i]->shared_from_this()); | ||||
} | } | ||||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) { | |||||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs, size_t) { | |||||
return apply(op, std::move(inputs)); | return apply(op, std::move(inputs)); | ||||
}; | }; | ||||
return graph.apply(inputs, apply_functor, &make_const); | return graph.apply(inputs, apply_functor, &make_const); | ||||
@@ -317,7 +317,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||||
template <typename T> | template <typename T> | ||||
auto apply(Subgraph graph, T&& tensors) | auto apply(Subgraph graph, T&& tensors) | ||||
-> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>, | |||||
-> std::enable_if_t<std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, | |||||
apply_result_t> { | apply_result_t> { | ||||
size_t nargs = tensors.size(); | size_t nargs = tensors.size(); | ||||
Tensor* args[nargs]; | Tensor* args[nargs]; | ||||
@@ -0,0 +1,105 @@ | |||||
/** | |||||
* \file imperative/src/impl/subgraph.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/imperative/subgraph.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
void Subgraph::remove_unused_exprs() { | |||||
std::unordered_set<size_t> required_vars = {outputs.begin(), outputs.end()}; | |||||
required_vars.erase(0); | |||||
for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) { | |||||
auto& expr = *iter; | |||||
bool required = false; | |||||
for (auto output : expr.outputs) { | |||||
if (required_vars.count(output)) { | |||||
required = true; | |||||
break; | |||||
} | |||||
} | |||||
if (required) { | |||||
required_vars.insert(expr.inputs.begin(), expr.inputs.end()); | |||||
} else { | |||||
expr.op = nullptr; | |||||
} | |||||
} | |||||
exprs.erase(std::remove_if(exprs.begin(), exprs.end(), | |||||
[](auto expr) { return expr.op == nullptr; }), | |||||
exprs.end()); | |||||
} | |||||
SmallVector<bool> Subgraph::gen_input_mask() {
    // Start by presuming every input var unused, then strike out each var
    // that is consumed by some expr or forwarded directly as an output.
    std::unordered_set<size_t> unused = {inputs.begin(), inputs.end()};
    for (auto&& expr : exprs) {
        for (auto&& var : expr.inputs) {
            unused.erase(var);
        }
    }
    for (auto&& var : outputs) {
        unused.erase(var);
    }
    // var 0 is unconditionally treated as unused
    unused.insert(0);
    const size_t ninputs = inputs.size();
    SmallVector<bool> mask(ninputs, true);
    for (size_t i = 0; i < ninputs; ++i) {
        mask[i] = !unused.count(inputs[i]);
    }
    return mask;
}
SmallVector<bool> Subgraph::gen_output_mask() {
    // An output var is valid iff something produces it: an input, an expr
    // output, or a constant. Whatever remains unstruck is invalid.
    std::unordered_set<size_t> invalid = {outputs.begin(), outputs.end()};
    for (auto&& var : inputs) {
        invalid.erase(var);
    }
    for (auto&& expr : exprs) {
        for (auto&& var : expr.outputs) {
            invalid.erase(var);
        }
    }
    for (auto&& constant : constants) {
        invalid.erase(constant.first);
    }
    // var 0 is unconditionally invalid
    invalid.insert(0);
    const size_t noutputs = outputs.size();
    SmallVector<bool> mask(noutputs, true);
    for (size_t i = 0; i < noutputs; ++i) {
        mask[i] = !invalid.count(outputs[i]);
    }
    return mask;
}
void Subgraph::replace_vars( | |||||
const std::unordered_map<size_t, size_t>& replace_map) { | |||||
// FIXME: preprocess replace_map | |||||
auto replace_var = [&](var_t& var) { | |||||
// TODO: detect infinite loop | |||||
while (replace_map.count(var)) { | |||||
var = replace_map.at(var); | |||||
} | |||||
}; | |||||
for (auto& expr : exprs) { | |||||
for (auto& input : expr.inputs) { | |||||
replace_var(input); | |||||
} | |||||
} | |||||
for (auto& output : outputs) { | |||||
replace_var(output); | |||||
} | |||||
} | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -14,6 +14,7 @@ | |||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/imperative/physical_tensor.h" | #include "megbrain/imperative/physical_tensor.h" | ||||
#include "megbrain/imperative/utils/to_string.h" | #include "megbrain/imperative/utils/to_string.h" | ||||
#include "megbrain/imperative/subgraph.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
@@ -28,54 +29,6 @@ enum DispatchMode { | |||||
using SharedOp = std::shared_ptr<OpDef>; | using SharedOp = std::shared_ptr<OpDef>; | ||||
template <typename T> | |||||
struct Expr { | |||||
std::shared_ptr<OpDef> op; | |||||
SmallVector<T> inputs; | |||||
SmallVector<T> outputs; | |||||
}; | |||||
struct Subgraph { | |||||
SmallVector<size_t> inputs; | |||||
SmallVector<std::pair<size_t, TensorPtr>> constants; | |||||
SmallVector<size_t> outputs; | |||||
SmallVector<Expr<size_t>> exprs; | |||||
template <typename T, typename F, typename C> | |||||
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||||
std::unordered_map<size_t, T> idx2var; | |||||
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
idx2var[inputs[i]] = input_vars[i]; | |||||
} | |||||
for (auto&& [idx, val]: constants) { | |||||
idx2var[idx] = c(val); | |||||
} | |||||
for (auto& expr: exprs) { | |||||
SmallVector<T> expr_inputs; | |||||
for (auto idx: expr.inputs) { | |||||
expr_inputs.push_back(idx2var[idx]); | |||||
} | |||||
SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs)); | |||||
mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch"); | |||||
for (size_t i = 0; i < expr_outputs.size(); ++i) { | |||||
idx2var[expr.outputs[i]] = expr_outputs[i]; | |||||
} | |||||
} | |||||
SmallVector<T> output_vars; | |||||
for (auto idx: outputs) { | |||||
output_vars.push_back(idx2var[idx]); | |||||
} | |||||
return output_vars; | |||||
} | |||||
bool empty() const { | |||||
return outputs.size() == 0; | |||||
} | |||||
std::string repr() const; | |||||
}; | |||||
struct BackwardGraphResult { | struct BackwardGraphResult { | ||||
Subgraph backward; | Subgraph backward; | ||||
SmallVector<bool> save_for_backward; | SmallVector<bool> save_for_backward; | ||||
@@ -0,0 +1,100 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/subgraph.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <list> | |||||
#include "megbrain/imperative/physical_tensor.h" | |||||
#include "megbrain/imperative/utils/to_string.h" | |||||
#include "megbrain/utils/small_vector.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
class OpDef; | |||||
// A single operation of a Subgraph program: applies `op` to `inputs`,
// producing `outputs`. T is the variable handle type; Subgraph instantiates
// Expr<size_t>, where each size_t is a var id.
template <typename T>
struct Expr {
    std::shared_ptr<OpDef> op;   // the operator; null is used as a tombstone
    SmallVector<T> inputs;       // vars consumed by the op
    SmallVector<T> outputs;      // vars produced by the op
};
template <typename T> | |||||
struct ToStringTrait<Expr<T>> { | |||||
std::string operator()(const Expr<T>& expr) { | |||||
return ssprintf("%s = %s %s\n", to_string(expr.inputs).c_str(), to_string(expr.op.get()).c_str(), to_string(expr.outputs).c_str()); | |||||
} | |||||
}; | |||||
// A small SSA-style program over OpDefs: a topologically ordered list of
// exprs that consume and produce numbered vars, plus designated input vars,
// constant vars bound to tensors, and output vars.
struct Subgraph {
    // Incremental construction helper; defined elsewhere.
    template <typename TDesc>
    class Builder;

    // Vars are dense size_t ids. Id 0 appears to be reserved as an
    // invalid/placeholder var -- TODO confirm against the producers.
    using var_t = size_t;
    using vars_t = SmallVector<size_t>;
    using op_t = std::shared_ptr<OpDef>;
    using expr_t = Expr<var_t>;
    template <typename TDesc>
    using builder_t = Builder<TDesc>;

    SmallVector<var_t> inputs;                           // externally supplied vars
    SmallVector<std::pair<var_t, TensorPtr>> constants;  // var -> constant tensor
    SmallVector<var_t> outputs;                          // vars returned by apply
    SmallVector<expr_t> exprs;                           // program body, in order

    // Interprets the subgraph over values of type T.
    //   input_vars: one T per entry of `inputs` (sizes must match).
    //   f(op, inputs, noutputs): evaluates one expr; must return exactly
    //       `noutputs` values (asserted below).
    //   c(tensor): lifts a constant TensorPtr into a T.
    // Exprs run in order; a var an expr reads must already be bound --
    // an unbound var would be silently default-constructed by the map's
    // operator[].
    template <typename T, typename F, typename C>
    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
        std::unordered_map<size_t, T> idx2var;
        mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
        for (size_t i = 0; i < inputs.size(); ++i) {
            idx2var[inputs[i]] = input_vars[i];
        }
        for (auto&& [idx, val] : constants) {
            idx2var[idx] = c(val);
        }
        for (auto& expr : exprs) {
            SmallVector<T> expr_inputs;
            for (auto idx : expr.inputs) {
                expr_inputs.push_back(idx2var[idx]);
            }
            SmallVector<T> expr_outputs =
                    f(expr.op, std::move(expr_inputs), expr.outputs.size());
            mgb_assert(expr_outputs.size() == expr.outputs.size(),
                       "output size mismatch");
            for (size_t i = 0; i < expr_outputs.size(); ++i) {
                idx2var[expr.outputs[i]] = expr_outputs[i];
            }
        }
        SmallVector<T> output_vars;
        for (auto idx : outputs) {
            output_vars.push_back(idx2var[idx]);
        }
        return output_vars;
    }

    // Drops exprs whose outputs contribute to no subgraph output.
    void remove_unused_exprs();
    // mask[i] == true iff inputs[i] is actually consumed by the subgraph.
    SmallVector<bool> gen_input_mask();
    // mask[i] == true iff outputs[i] is produced by an input/expr/constant.
    SmallVector<bool> gen_output_mask();
    bool empty() const { return outputs.size() == 0; }
    // Rewrites every var id through replace_map, following chains.
    void replace_vars(const std::unordered_map<size_t, size_t>& replace_map);
    // Human-readable dump; defined elsewhere.
    std::string repr() const;
    // Presumably: whether the subgraph is exactly one op applied to the
    // inputs (and as_single returns that op) -- defined elsewhere, confirm.
    bool is_single() const;
    std::shared_ptr<OpDef> as_single() const;
    bool operator==(const Subgraph& rhs) const;
};
} // namespace imperative | |||||
} // namespace mgb |