@@ -309,7 +309,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||
for (size_t i = 0; i < nargs; ++i) { | |||
inputs.push_back(args[i]->shared_from_this()); | |||
} | |||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) { | |||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs, size_t) { | |||
return apply(op, std::move(inputs)); | |||
}; | |||
return graph.apply(inputs, apply_functor, &make_const); | |||
@@ -317,7 +317,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||
template <typename T> | |||
auto apply(Subgraph graph, T&& tensors) | |||
-> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>, | |||
-> std::enable_if_t<std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, | |||
apply_result_t> { | |||
size_t nargs = tensors.size(); | |||
Tensor* args[nargs]; | |||
@@ -0,0 +1,105 @@ | |||
/** | |||
* \file imperative/src/impl/subgraph.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include "megbrain/imperative/subgraph.h" | |||
namespace mgb { | |||
namespace imperative { | |||
void Subgraph::remove_unused_exprs() { | |||
std::unordered_set<size_t> required_vars = {outputs.begin(), outputs.end()}; | |||
required_vars.erase(0); | |||
for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) { | |||
auto& expr = *iter; | |||
bool required = false; | |||
for (auto output : expr.outputs) { | |||
if (required_vars.count(output)) { | |||
required = true; | |||
break; | |||
} | |||
} | |||
if (required) { | |||
required_vars.insert(expr.inputs.begin(), expr.inputs.end()); | |||
} else { | |||
expr.op = nullptr; | |||
} | |||
} | |||
exprs.erase(std::remove_if(exprs.begin(), exprs.end(), | |||
[](auto expr) { return expr.op == nullptr; }), | |||
exprs.end()); | |||
} | |||
// Returns one flag per subgraph input: true iff that input var is actually
// consumed (by some expr or as a subgraph output).
SmallVector<bool> Subgraph::gen_input_mask() {
    // Assume every input is unused, then knock out each one we see consumed.
    std::unordered_set<size_t> unused_inputs{inputs.begin(), inputs.end()};
    for (auto&& expr : exprs) {
        for (auto&& consumed : expr.inputs) {
            unused_inputs.erase(consumed);
        }
    }
    for (auto&& forwarded : outputs) {
        unused_inputs.erase(forwarded);
    }
    // var 0 appears to be the reserved invalid var; always report it unused.
    unused_inputs.insert(0);
    SmallVector<bool> mask(inputs.size(), true);
    size_t idx = 0;
    for (auto&& var : inputs) {
        if (unused_inputs.count(var)) {
            mask[idx] = false;
        }
        ++idx;
    }
    return mask;
}
// Returns one flag per subgraph output: true iff that output var has a
// producer (a subgraph input, an expr output, or a bound constant).
SmallVector<bool> Subgraph::gen_output_mask() {
    // Assume every output is invalid, then knock out each one that is
    // produced somewhere.
    std::unordered_set<size_t> invalid_outputs{outputs.begin(), outputs.end()};
    for (auto&& var : inputs) {
        invalid_outputs.erase(var);
    }
    for (auto&& expr : exprs) {
        for (auto&& produced : expr.outputs) {
            invalid_outputs.erase(produced);
        }
    }
    for (auto&& [var, tensor] : constants) {
        invalid_outputs.erase(var);
    }
    // var 0 appears to be the reserved invalid var; it is never valid.
    invalid_outputs.insert(0);
    SmallVector<bool> mask(outputs.size(), true);
    for (size_t i = 0; i < outputs.size(); ++i) {
        mask[i] = !invalid_outputs.count(outputs[i]);
    }
    return mask;
}
void Subgraph::replace_vars( | |||
const std::unordered_map<size_t, size_t>& replace_map) { | |||
// FIXME: preprocess replace_map | |||
auto replace_var = [&](var_t& var) { | |||
// TODO: detect infinite loop | |||
while (replace_map.count(var)) { | |||
var = replace_map.at(var); | |||
} | |||
}; | |||
for (auto& expr : exprs) { | |||
for (auto& input : expr.inputs) { | |||
replace_var(input); | |||
} | |||
} | |||
for (auto& output : outputs) { | |||
replace_var(output); | |||
} | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -14,6 +14,7 @@ | |||
#include "megbrain/graph.h" | |||
#include "megbrain/imperative/physical_tensor.h" | |||
#include "megbrain/imperative/utils/to_string.h" | |||
#include "megbrain/imperative/subgraph.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -28,54 +29,6 @@ enum DispatchMode { | |||
using SharedOp = std::shared_ptr<OpDef>; | |||
template <typename T> | |||
struct Expr { | |||
std::shared_ptr<OpDef> op; | |||
SmallVector<T> inputs; | |||
SmallVector<T> outputs; | |||
}; | |||
struct Subgraph { | |||
SmallVector<size_t> inputs; | |||
SmallVector<std::pair<size_t, TensorPtr>> constants; | |||
SmallVector<size_t> outputs; | |||
SmallVector<Expr<size_t>> exprs; | |||
template <typename T, typename F, typename C> | |||
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||
std::unordered_map<size_t, T> idx2var; | |||
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
idx2var[inputs[i]] = input_vars[i]; | |||
} | |||
for (auto&& [idx, val]: constants) { | |||
idx2var[idx] = c(val); | |||
} | |||
for (auto& expr: exprs) { | |||
SmallVector<T> expr_inputs; | |||
for (auto idx: expr.inputs) { | |||
expr_inputs.push_back(idx2var[idx]); | |||
} | |||
SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs)); | |||
mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch"); | |||
for (size_t i = 0; i < expr_outputs.size(); ++i) { | |||
idx2var[expr.outputs[i]] = expr_outputs[i]; | |||
} | |||
} | |||
SmallVector<T> output_vars; | |||
for (auto idx: outputs) { | |||
output_vars.push_back(idx2var[idx]); | |||
} | |||
return output_vars; | |||
} | |||
bool empty() const { | |||
return outputs.size() == 0; | |||
} | |||
std::string repr() const; | |||
}; | |||
struct BackwardGraphResult { | |||
Subgraph backward; | |||
SmallVector<bool> save_for_backward; | |||
@@ -0,0 +1,100 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/subgraph.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#pragma once | |||
#include <list> | |||
#include "megbrain/imperative/physical_tensor.h" | |||
#include "megbrain/imperative/utils/to_string.h" | |||
#include "megbrain/utils/small_vector.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class OpDef; | |||
// One operation inside a Subgraph: applying `op` to `inputs` yields
// `outputs`. T is the variable-handle type (Subgraph uses var_t == size_t).
template <typename T>
struct Expr {
    std::shared_ptr<OpDef> op;
    SmallVector<T> inputs;
    SmallVector<T> outputs;
};
template <typename T> | |||
struct ToStringTrait<Expr<T>> { | |||
std::string operator()(const Expr<T>& expr) { | |||
return ssprintf("%s = %s %s\n", to_string(expr.inputs).c_str(), to_string(expr.op.get()).c_str(), to_string(expr.outputs).c_str()); | |||
} | |||
}; | |||
// A small SSA-like program: vars are size_t handles, exprs run in order, and
// `apply` re-executes the program over caller-supplied values of type T.
// NOTE(review): var 0 appears reserved as the invalid var (the .cpp masks it
// out explicitly) — confirm before relying on it here.
struct Subgraph {
    template <typename TDesc>
    class Builder;
    using var_t = size_t;
    using vars_t = SmallVector<size_t>;
    using op_t = std::shared_ptr<OpDef>;
    using expr_t = Expr<var_t>;
    template <typename TDesc>
    using builder_t = Builder<TDesc>;
    // Program structure: input vars, (var, value) constant bindings, output
    // vars, and the expr sequence in execution order.
    SmallVector<var_t> inputs;
    SmallVector<std::pair<var_t, TensorPtr>> constants;
    SmallVector<var_t> outputs;
    SmallVector<expr_t> exprs;
    // Interprets the subgraph over values of type T.
    //   input_vars: one T per entry of `inputs` (asserted).
    //   f(op, inputs, nouts): executes one expr, returns its output values
    //                         (count asserted against the expr).
    //   c(tensor): lifts a constant TensorPtr into a T.
    // Returns one T per entry of `outputs`.
    template <typename T, typename F, typename C>
    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
        // Environment mapping var handle -> current value.
        std::unordered_map<size_t, T> idx2var;
        mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
        for (size_t i = 0; i < inputs.size(); ++i) {
            idx2var[inputs[i]] = input_vars[i];
        }
        for (auto&& [idx, val] : constants) {
            idx2var[idx] = c(val);
        }
        // Exprs are evaluated strictly in order; each may read any var bound
        // so far and binds its own outputs.
        for (auto& expr : exprs) {
            SmallVector<T> expr_inputs;
            for (auto idx : expr.inputs) {
                expr_inputs.push_back(idx2var[idx]);
            }
            SmallVector<T> expr_outputs =
                    f(expr.op, std::move(expr_inputs), expr.outputs.size());
            mgb_assert(expr_outputs.size() == expr.outputs.size(),
                       "output size mismatch");
            for (size_t i = 0; i < expr_outputs.size(); ++i) {
                idx2var[expr.outputs[i]] = expr_outputs[i];
            }
        }
        SmallVector<T> output_vars;
        for (auto idx : outputs) {
            output_vars.push_back(idx2var[idx]);
        }
        return output_vars;
    }
    // Drops exprs whose outputs are never consumed (defined in subgraph.cpp).
    void remove_unused_exprs();
    // Per-input flag: true iff the input var is consumed somewhere.
    SmallVector<bool> gen_input_mask();
    // Per-output flag: true iff the output var has a producer.
    SmallVector<bool> gen_output_mask();
    // Empty means no outputs — nothing to compute.
    bool empty() const { return outputs.size() == 0; }
    // Transitively rewrites expr inputs and subgraph outputs via replace_map.
    void replace_vars(const std::unordered_map<size_t, size_t>& replace_map);
    std::string repr() const;
    // Single-op queries — declared here, defined elsewhere; semantics not
    // visible in this file.
    bool is_single() const;
    std::shared_ptr<OpDef> as_single() const;
    bool operator==(const Subgraph& rhs) const;
};
} // namespace imperative | |||
} // namespace mgb |