GitOrigin-RevId: 595392ec89
release-1.6
@@ -0,0 +1,547 @@
/**
 * \file src/gopt/impl/dynamic_programming_solver.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include <queue>
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"

using namespace mgb;
using namespace gopt;
using namespace cg;

/* ================= DynamicProgrammingSolver::Impl ==================*/
class DynamicProgrammingSolver::Impl {
public:
    Impl(size_t max_states) : m_max_states{max_states} {}
    ~Impl() = default;

    Solution solve(const ProfilerBase* profiler, const Problem& problem);

private:
    using TensorFormatsBitSet = uint32_t;
    using State = SmallVector<TensorFormatsBitSet>;
    // number of representable tensor formats = bit width of the set
    static constexpr uint32_t MAX_TENSOR_FORMATS =
            sizeof(TensorFormatsBitSet) * 8;
    TensorFormatsBitSet add(TensorFormatsBitSet& set, TensorFormats fmt) {
        mgb_assert(static_cast<uint32_t>(fmt) < MAX_TENSOR_FORMATS);
        set |= (1 << static_cast<uint32_t>(fmt));
        return set;
    }
    bool valid(const TensorFormatsBitSet& set, TensorFormats fmt) {
        mgb_assert(static_cast<uint32_t>(fmt) < MAX_TENSOR_FORMATS);
        bool val = set & (1 << static_cast<uint32_t>(fmt));
        return val;
    }
    struct Value {
        OperatorNodeBase* opr;
        const State* prev;
        OprFormat opr_fmt;
        float time;
        ///! index of the corresponding operator in the topological order
        size_t opr_idx;
    };
    struct StateHash {
        size_t operator()(const State& key) const {
            size_t h = 0;
            for (auto&& v : key) {
                h = mgb::hash_pair_combine(
                        h, std::hash<TensorFormatsBitSet>{}(v));
            }
            return h;
        }
    };
    struct StateEqual {
        bool operator()(const State& lhs, const State& rhs) const {
            if (lhs.size() != rhs.size())
                return false;
            for (size_t i = 0; i < lhs.size(); ++i) {
                if (lhs[i] != rhs[i])
                    return false;
            }
            return true;
        }
    };
    using StateTable = std::unordered_map<State, Value, StateHash, StateEqual>;
    struct Cut {
        StateTable states;
    };
    using ProfilingResult = ProfilerBase::ProfilingResult;
    using OprConfigTrait = LayoutTransformContext::OprConfigTrait;
    struct Context {
        const std::vector<OperatorNodeBase*>& topo;
        const ProfilingResult& rst;
        const OprConfigTrait& opr_configs;
        const SmallVector<TensorFormats>& available_tensor_formats;
    };
    /*!
     * \brief get the tensor formats configuration for the operator with a
     * particular op format
     * \param[out] var2fmts hashmap that maps each input varnode to the actual
     * tensor format of the op format configuration
     * \param[in] opr given operator
     * \param[in] opr_fmt given op format, an enum type argument which
     * indicates the op format configuration.
     * \param[in] ctx context
     */
    TensorFormats get_io_formats(ThinHashMap<VarNode*, TensorFormats>& var2fmts,
                                 const OperatorNodeBase* opr, OprFormat opr_fmt,
                                 const Context& ctx);
    /*!
     * \brief compute the distance between two states of the given varnode
     * \param[in] from the source state
     * \param[in] to the target state
     * \param[in] var given varnode
     * \param[in] ctx context
     */
    float distance(const TensorFormatsBitSet& from,
                   const TensorFormatsBitSet& to, VarNode* var,
                   const Context& ctx);
    /*!
     * \brief compute the distance between two states of the given cut edges
     * \param[in] from the source state
     * \param[in] to the target state
     * \param[in] edge a VarNodeArray, the given cut edges
     * \param[in] ctx context
     */
    float state_distance(const State& from, const State& to,
                         const VarNodeArray& edge, const Context& ctx);
    /*!
     * \brief analyze the edges of each cut
     * \param[out] edges the returned edges of the cuts
     * \param[out] edge2idx hashmaps that map each edge (varnode) to its index
     * \param[in] ctx context
     */
    void analyze_edges(SmallVector<VarNodeArray>& edges,
                       SmallVector<std::unordered_map<VarNode*, int>>& edge2idx,
                       const Context& ctx);
    /*!
     * \brief prune states using the distance of states
     */
    void prune(StateTable& states, const VarNodeArray& edge,
               const Context& ctx);
    /*!
     * \brief forcibly prune the states, keeping the m_max_states states with
     * the smallest accumulated times
     */
    void force_prune(StateTable& states);

private:
    size_t m_max_states;
};
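
/*
 * A minimal illustration of the bitset encoding used by the solver (a
 * sketch, not part of the implementation): each TensorFormats enumerator
 * occupies one bit of a TensorFormatsBitSet, so a set of candidate layouts
 * for one varnode fits in a single uint32_t.
 *
 *     TensorFormatsBitSet set = 0;
 *     set |= 1 << static_cast<uint32_t>(TensorFormats::NCHW);  // add NCHW
 *     set |= 1 << static_cast<uint32_t>(TensorFormats::NHWC);  // add NHWC
 *     bool has_nchw =
 *             set & (1 << static_cast<uint32_t>(TensorFormats::NCHW));
 *
 * A State stores one such bitset per cut edge, so a cut with k live
 * varnodes is described by a vector of k bitsets.
 */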
TensorFormats DynamicProgrammingSolver::Impl::get_io_formats(
        ThinHashMap<VarNode*, TensorFormats>& var2fmts,
        const OperatorNodeBase* opr, OprFormat opr_fmt, const Context& ctx) {
    auto&& rst = ctx.rst;
    auto&& opr_configs = ctx.opr_configs;
    auto iter = opr_configs.find(opr->dyn_typeinfo());
    Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
    if (iter != opr_configs.end()) {
        fmtcfg = (*iter->second.at(opr_fmt))(opr);
    }
    TensorFormats out_fmt;
    if (fmtcfg.valid())
        out_fmt = fmtcfg.val().output_tensor_formats[0];
    else
        out_fmt = opr_format_to_tensor_formats(opr_fmt);
    for (size_t i = 0; i < opr->input().size(); ++i) {
        auto&& var = opr->input(i);
        auto iter = rst.var_record.find(var);
        if (iter != rst.var_record.end()) {
            if (fmtcfg.valid())
                var2fmts[var] = fmtcfg.val().input_tensor_formats[i];
            else
                var2fmts[var] = opr_format_to_tensor_formats(opr_fmt);
        }
    }
    return out_fmt;
}
float DynamicProgrammingSolver::Impl::distance(const TensorFormatsBitSet& from,
                                               const TensorFormatsBitSet& to,
                                               VarNode* var,
                                               const Context& ctx) {
    auto&& costs = ctx.rst.var_record.at(var).costs;
    auto&& available_tensor_formats = ctx.available_tensor_formats;
    float dist = 0.f;
    if ((from & to) == to)
        return dist;
    auto to_set = ((from | to) ^ from);
    for (auto o : available_tensor_formats) {
        if (valid(to_set, o)) {
            float o_cost = std::numeric_limits<float>::max();
            for (auto i : available_tensor_formats) {
                if (valid(from, i)) {
                    float cost = costs.at({i, o});
                    o_cost = std::min(o_cost, cost);
                }
            }
            dist += o_cost;
        }
    }
    return dist;
}
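
/*
 * Worked example for distance() (a sketch; the costs come from the
 * profiling records): with from = {NCHW} and to = {NCHW, NHWC}, the set
 * difference to_set is {NHWC}, so the distance is the cheapest conversion
 * into NHWC from any format already present in `from`, i.e.
 * costs.at({NCHW, NHWC}). Formats already contained in `from` contribute
 * nothing.
 */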
float DynamicProgrammingSolver::Impl::state_distance(const State& from,
                                                     const State& to,
                                                     const VarNodeArray& edge,
                                                     const Context& ctx) {
    float dist = 0.f;
    mgb_assert(from.size() == to.size() && from.size() == edge.size());
    for (size_t i = 0; i < edge.size(); ++i) {
        dist += distance(from[i], to[i], edge[i], ctx);
    }
    return dist;
}

void DynamicProgrammingSolver::Impl::analyze_edges(
        SmallVector<VarNodeArray>& edges,
        SmallVector<std::unordered_map<VarNode*, int>>& edge2idx,
        const Context& ctx) {
    auto&& topo = ctx.topo;
    auto&& rst = ctx.rst;
    size_t nr_oprs = topo.size();
    edges.resize(nr_oprs);
    edge2idx.resize(nr_oprs);
    ThinHashSet<VarNode*> cur_edge;
    size_t cur = nr_oprs - 1;
    int idx = 0;
    for (auto&& ov : topo[cur]->usable_output()) {
        edges[cur].push_back(ov);
        edge2idx[cur].emplace(ov, idx++);
    }
    cur--;
    for (const auto& opr : reverse_adaptor(topo)) {
        for (const auto& i : opr->input()) {
            if (rst.var_record.count(i) > 0) {
                cur_edge.insert(i);
            }
        }
        for (auto&& ov : opr->usable_output()) {
            cur_edge.erase(ov);
        }
        edges[cur].insert(edges[cur].begin(), cur_edge.begin(), cur_edge.end());
        int i = 0;
        for (auto&& e : edges[cur]) {
            edge2idx[cur][e] = i++;
        }
        if (cur == 0)
            break;
        cur--;
    }
}
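
/*
 * Example (a sketch): for a chain a -> b -> c, edges[i] holds the varnodes
 * that are still alive right after the i-th operator, e.g. edges[2] is
 * simply c's usable outputs. edge2idx[i] maps each of these varnodes to its
 * position inside edges[i], which is also the position of its bitset inside
 * a State of the i-th cut.
 */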
void DynamicProgrammingSolver::Impl::prune(StateTable& states,
                                           const VarNodeArray& edge,
                                           const Context& ctx) {
    struct Item {
        decltype(states.begin()) iter;
    };
    std::list<Item> list;
    for (auto it = states.begin(); it != states.end(); ++it) {
        list.emplace_back(Item{it});
    }
    SmallVector<State> removed_states;
    for (auto i = list.begin(); i != list.end();) {
        bool advanced_i = false;
        for (auto j = std::next(i, 1); j != list.end();) {
            if (i->iter->second.time > j->iter->second.time &&
                state_distance(j->iter->first, i->iter->first, edge, ctx) <
                        i->iter->second.time - j->iter->second.time) {
                removed_states.push_back(i->iter->first);
                i = list.erase(i);
                advanced_i = true;
                break;
            } else if (i->iter->second.time < j->iter->second.time &&
                       state_distance(i->iter->first, j->iter->first, edge,
                                      ctx) <
                               j->iter->second.time - i->iter->second.time) {
                removed_states.push_back(j->iter->first);
                j = list.erase(j);
            } else {
                j = std::next(j, 1);
            }
        }
        if (!advanced_i)
            i = std::next(i, 1);
    }
    for (auto&& state : removed_states)
        states.erase(state);
}
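
/*
 * The pruning above is a dominance test: a state A is dropped whenever some
 * cheaper state B satisfies
 *
 *     time(A) - time(B) > state_distance(B, A),
 *
 * because any completion of A can instead start from B and pay at most
 * state_distance(B, A) extra to reconstruct the layouts that B lacks.
 */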
void DynamicProgrammingSolver::Impl::force_prune(StateTable& states) {
    if (states.size() < m_max_states)
        return;
    struct Item {
        decltype(states.begin()) iter;
    };
    auto cmp = [](Item lhs, Item rhs) {
        return lhs.iter->second.time < rhs.iter->second.time;
    };
    std::priority_queue<Item, std::vector<Item>, decltype(cmp)> pq(cmp);
    for (auto it = states.begin(); it != states.end(); ++it) {
        if (pq.size() < m_max_states)
            pq.push(Item{it});
        else {
            auto i = pq.top();
            if (it->second.time < i.iter->second.time) {
                pq.pop();
                pq.push(Item{it});
            }
        }
    }
    StateTable active_state;
    while (!pq.empty()) {
        auto i = pq.top();
        active_state.insert(*i.iter);
        pq.pop();
    }
    states.swap(active_state);
}
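
/*
 * Note: the priority queue above is a max-heap keyed on the accumulated
 * time, so once it is full each insertion evicts the currently largest
 * entry; what finally survives are the m_max_states cheapest states.
 */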
DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
        const ProfilerBase* profiler, const Problem& problem) {
    const auto rst = profiler->profile(problem);
    const auto& partition = problem.graph_partition();
    const auto& opr_configs = problem.opr_configs();
    const auto& base_fmt = problem.base_format();
    const auto& available_tensor_formats = problem.available_tensor_formats();
    const auto& topo = partition.all_oprs();
    Context ctx{topo, rst, opr_configs, available_tensor_formats};
    SmallVector<VarNodeArray> edges;
    SmallVector<std::unordered_map<VarNode*, int>> edge2idx;
    /// analyze the edges of each cut
    analyze_edges(edges, edge2idx, ctx);
    SmallVector<Cut> cuts;
    size_t cur = 0;
    /// initialize states
    auto init = [&, this](OperatorNodeBase* opr) {
        auto it = rst.opr_record.find(opr);
        if (it == rst.opr_record.end())
            return;
        ThinHashSet<VarNode*> ovar_set;
        for (auto&& ov : opr->usable_output()) {
            ovar_set.insert(ov);
        }
        const auto& records = it->second.costs;
        cuts.emplace_back(Cut{});
        auto& states = cuts.back().states;
        for (const auto& record : records) {
            auto opr_fmt = record.first;
            float opr_time = record.second;
            ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
            auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
            const auto& edge = edges[cur];
            State state(edge.size(), 0);
            Value value{opr, nullptr, opr_fmt, 0.f, cur};
            float ovar_time = 0.f;
            for (size_t i = 0; i < edge.size(); ++i) {
                auto&& var = edge[i];
                auto&& costs = rst.var_record.at(var).costs;
                if (ovar_set.count(var) > 0) {
                    add(state[i], out_fmt);
                    if (partition.output().count(var) > 0 &&
                        out_fmt != base_fmt) {
                        ovar_time += costs.at({out_fmt, base_fmt});
                        add(state[i], base_fmt);
                    }
                } else {
                    add(state[i], base_fmt);
                }
            }
            float ivar_time = 0.f;
            for (const auto& kv : ivar2fmts) {
                auto&& v = kv.first;
                auto&& costs = rst.var_record.at(v).costs;
                auto to = kv.second;
                float min_time = std::numeric_limits<float>::max();
                if (base_fmt == to) {
                    min_time = 0.f;
                } else {
                    min_time = costs.at({base_fmt, to});
                    if (edge2idx[cur].count(v) > 0) {
                        add(state[edge2idx[cur][v]], to);
                    }
                }
                ivar_time += min_time;
            }
            value.time = opr_time + ivar_time + ovar_time;
            states[state] = value;
        }
    };
    /// update the states
    auto body = [&, this](OperatorNodeBase* opr) {
        auto it = rst.opr_record.find(opr);
        if (it == rst.opr_record.end())
            return;
        ThinHashSet<VarNode*> ovar_set;
        for (auto&& ov : opr->usable_output()) {
            ovar_set.insert(ov);
        }
        const auto& records = it->second.costs;
        StateTable states;
        for (const auto& record : records) {
            auto opr_fmt = record.first;
            float opr_time = record.second;
            ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
            auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
            for (const auto& kv : cuts.back().states) {
                auto&& prev_state = kv.first;
                float prev_time = kv.second.time;
                const auto& edge = edges[cur];
                State state(edge.size(), 0);
                Value value{opr, &prev_state, opr_fmt, 0.f, cur};
                float ovar_time = 0.f;
                for (size_t i = 0; i < edge.size(); ++i) {
                    auto&& var = edge[i];
                    auto&& costs = rst.var_record.at(var).costs;
                    auto iter = edge2idx[cur - 1].find(var);
                    if (iter != edge2idx[cur - 1].end()) {
                        state[i] = prev_state[iter->second];
                    } else {
                        mgb_assert(ovar_set.count(var) > 0);
                        add(state[i], out_fmt);
                        if (partition.output().count(var) > 0 &&
                            out_fmt != base_fmt) {
                            ovar_time += costs.at({out_fmt, base_fmt});
                            add(state[i], base_fmt);
                        }
                    }
                }
                float ivar_time = 0.f;
                for (const auto& kv : ivar2fmts) {
                    auto&& v = kv.first;
                    auto&& costs = rst.var_record.at(v).costs;
                    auto to = kv.second;
                    auto it1 = edge2idx[cur - 1].find(v);
                    float min_time = std::numeric_limits<float>::max();
                    if (valid(prev_state[it1->second], to)) {
                        min_time = 0.f;
                    } else {
                        for (auto&& from : available_tensor_formats) {
                            if (valid(prev_state[it1->second], from)) {
                                float cost = costs.at({from, to});
                                min_time = std::min(min_time, cost);
                            }
                        }
                    }
                    auto it2 = edge2idx[cur].find(v);
                    if (it2 != edge2idx[cur].end()) {
                        add(state[it2->second], to);
                    }
                    ivar_time += min_time;
                }
                value.time = prev_time + opr_time + ivar_time + ovar_time;
                auto iter = states.find(state);
                if (iter == states.end()) {
                    states[state] = value;
                } else {
                    float time = iter->second.time;
                    if (value.time < time) {
                        iter->second = value;
                    }
                }
            }
        }
        cuts.emplace_back(Cut{});
        cuts.back().states.swap(states);
    };
    /// forward pass to generate all states
    for (auto&& opr : topo) {
        if (cuts.empty()) {
            init(opr);
        } else {
            body(opr);
        }
        if (!cuts.empty()) {
            auto& states = cuts.back().states;
            prune(states, edges[cur], ctx);
            force_prune(states);
        }
        cur++;
    }
    Solution solution;
    /// backward pass to generate the solution
    float min_time = std::numeric_limits<float>::max();
    OperatorNodeBase* cur_opr = nullptr;
    OprFormat min_fmt;
    const State* pstate = nullptr;
    for (auto&& kv : cuts.back().states) {
        auto&& v = kv.second;
        if (v.time < min_time) {
            cur_opr = v.opr;
            pstate = v.prev;
            min_time = v.time;
            min_fmt = v.opr_fmt;
            ///! sanity check: output varnodes must contain the base format
            auto&& k = kv.first;
            size_t opr_idx = v.opr_idx;
            for (size_t i = 0; i < k.size(); ++i) {
                auto&& fmt_set = k[i];
                auto&& var = edges[opr_idx][i];
                if (partition.output().count(var)) {
                    mgb_assert(valid(fmt_set, base_fmt));
                }
            }
        }
    }
    mgb_log_debug("opr:%s;format:%s;time:%f", cur_opr->cname(),
                  opr_format_to_string(min_fmt), min_time);
    solution.insert({cur_opr, min_fmt});
    cur = cuts.size() - 2;
    while (pstate) {
        auto val = cuts[cur].states[*pstate];
        ///! sanity check: output varnodes must contain the base format
        size_t opr_idx = val.opr_idx;
        for (size_t i = 0; i < pstate->size(); ++i) {
            auto&& fmt_set = pstate->operator[](i);
            auto&& var = edges[opr_idx][i];
            if (partition.output().count(var)) {
                mgb_assert(valid(fmt_set, base_fmt));
            }
        }
        mgb_log_debug("opr:%s;format:%s;time:%f", val.opr->cname(),
                      opr_format_to_string(val.opr_fmt), val.time);
        solution.insert({val.opr, val.opr_fmt});
        pstate = val.prev;
        cur--;
    }
    return solution;
}
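
/*
 * In short, the forward pass evaluates the recurrence (a sketch in ad-hoc
 * notation, using the names above):
 *
 *     T[cut][s] = min over (s', opr_fmt) of
 *                 T[cut - 1][s'] + opr_time(opr_fmt)
 *                                + ivar_time(s' -> opr_fmt)
 *                                + ovar_time(opr_fmt),
 *
 * where s ranges over the pruned layout states of the current cut; the
 * backward pass then follows the stored `prev` pointers to recover the op
 * format assignment with the minimal total time.
 */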
/* =================== DynamicProgrammingSolver ======================*/
DynamicProgrammingSolver::Solution DynamicProgrammingSolver::do_solve(
        const Problem& problem) const {
    constexpr size_t MAX_STATES = 1024;
    Impl impl(MAX_STATES);
    return impl.solve(m_profiler.get(), problem);
}

bool DynamicProgrammingSolver::can_solve(const Problem& problem) const {
    auto&& available_tensor_formats = problem.available_tensor_formats();
    for (auto&& tensor_format : available_tensor_formats) {
        if (static_cast<uint32_t>(tensor_format) >= 32)
            return false;
    }
    return true;
}

// vim: syntax=cpp.doxygen
@@ -0,0 +1,40 @@
/**
 * \file src/gopt/impl/layout_transform_context.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./utils.h" | |||
#include "megbrain/gopt/global_layout_transform.h" | |||
using namespace mgb; | |||
using namespace gopt; | |||
/* ================= LayoutTransformContext ==================*/ | |||
LayoutTransformContext& LayoutTransformContext::add_opr_config( | |||
Typeinfo* opr, OprFormat opr_format) { | |||
auto& dispatchers = m_opr_configs[opr]; | |||
dispatchers[opr_format] = | |||
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( | |||
opr, opr_format); | |||
return *this; | |||
} | |||
LayoutTransformContext& LayoutTransformContext::add_opr_config( | |||
Typeinfo* opr, SmallVector<OprFormat> opr_formats) { | |||
auto& dispatchers = m_opr_configs[opr]; | |||
for (auto opr_fmt : opr_formats) { | |||
dispatchers[opr_fmt] = | |||
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( | |||
opr, opr_fmt); | |||
} | |||
return *this; | |||
} | |||
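
/*
 * Typical usage (a sketch; the operator types and formats are only
 * examples, see the tests for a complete configuration):
 *
 *     ctx.add_opr_config(opr::ConvBiasForward::typeinfo(), OprFormat::NHWC)
 *        .add_opr_config(opr::PoolingForward::typeinfo(),
 *                        {OprFormat::NCHW4, OprFormat::NHWC});
 */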
// vim: syntax=cpp.doxygen
@@ -17,6 +17,7 @@
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"
@@ -265,6 +266,10 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
    record.opr = opr;
    auto& costs = record.costs;
    for (auto&& i : available_configs) {
        /// XXXX remove later
        if (i.opr_format == OprFormat::NCHW &&
            opr->input(0)->dtype().enumv() != DTypeEnum::Float32)
            continue;
        costs[i.opr_format] = profile_operator(opr, base_config, i);
    }
    return record;
@@ -414,37 +419,42 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
            cb(Resize, 1),
#undef cb
    };
    static const ThinHashSet<Typeinfo*> skip_opr_types = {
            TypeCvt::typeinfo(), Elemwise::typeinfo(),
            ElemwiseMultiType::typeinfo()};
    ThinHashSet<VarNode*> vars;
    ThinHashSet<OperatorNodeBase*> oprs;
    {
        auto cb = [&cvprop, &vars, &oprs](OperatorNodeBase* opr) {
            if (cvprop.is_const(opr))
                return;
            oprs.insert(opr);
            auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
            if (find == format_aware_input_tensors.end()) {
                for (auto&& i : opr->input()) {
                    if (!cvprop.is_const(i)) {
                        vars.insert(i);
                    }
                }
            } else {
                size_t nr_input_tensor =
                        std::min(find->second, opr->input().size());
                for (size_t i = 0; i < nr_input_tensor; ++i) {
                    if (!cvprop.is_const(opr->input(i))) {
                        vars.insert(opr->input(i));
                    }
                }
            }
            vars.insert(opr->output(0));
        };
        DepOprIter iter{cb};
        for (auto&& i : problem.graph_partition().input()) {
            iter.set_visited(i->owner_opr());
        }
        for (auto&& o : problem.graph_partition().output()) {
            iter.add(o->owner_opr());
        }
    }
    ThinHashSet<OperatorNodeBase*> skip_oprs;
    for (auto&& opr : problem.graph_partition().all_oprs()) {
        if (cvprop.is_const(opr))
            continue;
        bool skip = true;
        for (auto&& i : opr->input()) {
            skip &= problem.graph_partition().input().count(i) > 0 ||
                    skip_oprs.count(i->owner_opr()) > 0;
        }
        skip &= skip_opr_types.count(opr->dyn_typeinfo());
        if (skip)
            skip_oprs.insert(opr);
        oprs.insert(opr);
        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
        if (find == format_aware_input_tensors.end()) {
            for (auto&& i : opr->input()) {
                if (!cvprop.is_const(i)) {
                    vars.insert(i);
                }
            }
        } else {
            size_t nr_input_tensor =
                    std::min(find->second, opr->input().size());
            for (size_t i = 0; i < nr_input_tensor; ++i) {
                if (!cvprop.is_const(opr->input(i))) {
                    vars.insert(opr->input(i));
                }
            }
        }
        for (auto&& ov : opr->usable_output()) {
            vars.insert(ov);
        }
    }
@@ -462,8 +472,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
        auto&& opr_configs = problem.opr_configs();
        auto find = opr_configs.find(opr->dyn_typeinfo());
        if (find == opr_configs.end()) {
            opr_record[opr] = profile_operator(opr, base_format,
                                               available_tensor_formats);
            if (skip_oprs.count(opr) > 0) {
                SmallVector<TensorFormats> tensor_formats = {base_format};
                opr_record[opr] =
                        profile_operator(opr, base_format, tensor_formats);
            } else {
                opr_record[opr] = profile_operator(opr, base_format,
                                                   available_tensor_formats);
            }
        } else {
            auto&& dispatchers = find->second;
            SmallVector<OprTensorFormatsConfiguration> configs;
@@ -0,0 +1,56 @@
/**
 * \file src/gopt/impl/profiling_based_solver.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/gopt/global_layout_transform.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/imgproc.h" | |||
using namespace mgb; | |||
using namespace gopt; | |||
using namespace opr; | |||
/* =================== ProfilingBasedSolver ======================*/
ProfilingBasedSolver::ProfilingBasedSolver(
        std::unique_ptr<ProfilerBase> profiler)
        : m_profiler{std::move(profiler)} {
    static const ThinHashSet<Typeinfo*> format_aware_oprs = {
#define cb(_Opr) _Opr::typeinfo()
            cb(Convolution),
            cb(ConvBiasForward),
            cb(ConvolutionBackwardData),
            cb(PoolingForward),
            cb(WarpPerspective),
            cb(Resize),
#undef cb
    };
    m_graph_partition_filter = [](const GraphPartition& partition) {
        bool has_format_aware_opr = false;
        for (auto&& opr : partition.all_oprs()) {
            if (!has_format_aware_opr &&
                format_aware_oprs.count(opr->dyn_typeinfo())) {
                has_format_aware_opr = true;
                break;
            }
        }
        return has_format_aware_opr;
    };
}

ProfilingBasedSolver::Solution ProfilingBasedSolver::solve(
        const Problem& problem) const {
    const auto& partition = problem.graph_partition();
    if (!m_graph_partition_filter(partition))
        return Solution{};
    return do_solve(problem);
}

// vim: syntax=cpp.doxygen
@@ -11,9 +11,9 @@
 */
#include "megbrain/gopt/reformat_manager.h"
#include "./utils.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/utils/arith_helper.h"
#include "./utils.h"

using namespace mgb;
using namespace gopt;
@@ -87,21 +87,6 @@ bool ReformatManager::ReformatKey::Equal::operator()(
           lhs.attribute == rhs.attribute;
}

ReformatManager::ReformatKey&
ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) {
    static const ThinHashSet<std::pair<TensorFormats, TensorFormats>> set = {
            {TensorFormats::NCHW, TensorFormats::NCHWc64},
            {TensorFormats::NCHWc64, TensorFormats::NCHW},
            {TensorFormats::NCHW, TensorFormats::NHWC},
            {TensorFormats::NHWC, TensorFormats::NCHW}};
    if (set.count({input_format, output_format}) > 0 &&
        (dt.enumv() == DTypeEnum::QuantizedS4 ||
         dt.enumv() == DTypeEnum::Quantized4Asymm)) {
        input_dtype = output_dtype = dt.enumv();
    }
    return *this;
}

/* =================== ReformatManager ====================*/
ReformatManager::ReformatManager() {
    using Attribute = ReformatKey::Attribute;
@@ -378,7 +363,7 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue(
            divup(orig_channel, input_alignment) * input_alignment;
    size_t aligned_out_channel =
            divup(orig_channel, output_alignment) * output_alignment;
    size_t common_alignment = input_alignment * output_alignment /
    size_t common_alignment = input_alignment * output_alignment /
                              gcd(input_alignment, output_alignment);
    size_t aligned_channel =
            divup(orig_channel, common_alignment) * common_alignment;
@@ -427,11 +412,11 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
    for (size_t i = 0; i < input_shape.ndim; ++i) {
        if (input_shape[i].name() == Dimension::Name::C &&
            input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
            in_channels = orig_var->shape()[i];
            in_channels = orig_var->shape()[i] * input_shape[i].stride();
            input_channel_idx = i;
            mgb_assert(input_shape[i].stride() == 1,
                       "unsupport weight format(got:%s)",
                       input_shape.to_string().c_str());
            // mgb_assert(input_shape[i].stride() == 1,
            //            "unsupport weight format(got:%s)",
            //            input_shape.to_string().c_str());
        } else if ((input_shape[i].name() == Dimension::Name::K ||
                    input_shape[i].name() == Dimension::Name::N) &&
                   input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
@@ -536,7 +521,8 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var,
               "formats(var:%s;shp:%s;fmt:%s)",
               var->cname(), oshp.to_string().c_str(),
               orig_shape.to_string().c_str());
    if (oshp.is_scalar()) return oshp;
    if (oshp.is_scalar())
        return oshp;
    TensorShape tshp;
    ThinHashMap<Dimension::Name, int> name2dominant;
    for (size_t i = 0; i < orig_shape.ndim; ++i) {
@@ -597,4 +583,32 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var,
    return tshp;
}

ReformatManager::AlignmentDesc mgb::gopt::make_aligned_desc(
        TensorFormats weight_format, TensorFormats out_feature_format) {
    using AlignmentDesc = ReformatManager::AlignmentDesc;
    using Name = Dimension::Name;
    auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format);
    auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format);
    size_t out_channel_alignment = 1;
    for (size_t i = 0; i < out_shape.ndim; ++i) {
        auto name = out_shape[i].name();
        auto extent = out_shape[i].extent();
        if ((name == Name::C || name == Name::K) &&
            extent == Dimension::UNDETERMINED_EXTENT) {
            out_channel_alignment = out_shape[i].stride();
            break;
        }
    }
    Name out_channel_name;
    for (size_t i = 0; i < weight_shape.ndim; ++i) {
        auto name = weight_shape[i].name();
        auto extent = weight_shape[i].extent();
        if ((name == Name::N || name == Name::K) &&
            extent == Dimension::UNDETERMINED_EXTENT) {
            out_channel_name = name;
        }
    }
    return AlignmentDesc{out_channel_name, out_channel_alignment};
}

// vim: syntax=cpp.doxygen
@@ -304,10 +304,15 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
                }
            }
            partition->opr_set().insert(opr);
            partition->all_oprs().push_back(opr);
            for (const auto& i : opr->input())
                partition->input().insert(i);
        }
    }
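    /// the operators were collected in reverse order above, so reverse each
    /// partition's operator list to obtain topological order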
    for (auto&& partition : partitions) {
        auto& all_oprs = partition.all_oprs();
        std::reverse(all_oprs.begin(), all_oprs.end());
    }
    return partitions;
}
@@ -36,6 +36,28 @@ static inline const char* opr_format_to_string(
#undef cb
}

static inline TensorFormats opr_format_to_tensor_formats(
        OprTensorFormatsConfiguration::OprFormat opr_format) {
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    switch (opr_format) {
        case OprFormat::NCHW:
            return TensorFormats::NCHW;
        case OprFormat::NHWC:
            return TensorFormats::NHWC;
        case OprFormat::NCHW4:
            return TensorFormats::NCHWc4;
        case OprFormat::NCHW32:
            return TensorFormats::NCHWc32;
        case OprFormat::NCHW64:
            return TensorFormats::NCHWc64;
        case OprFormat::CHWN4:
            return TensorFormats::CHWNc4;
        default:
            mgb_throw(AssertionError, "format(%s) is not supported",
                      opr_format_to_string(opr_format));
    }
}

static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
        TensorFormats format) {
    switch (format) {
@@ -11,6 +11,7 @@
 */
#pragma once
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/reformat_manager.h"
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/dnn/convolution.h"
@@ -41,14 +42,16 @@ struct OprTensorFormatsConfiguration {
/*!
 * \brief A structure that describes the global layout transform problem
 */
class Problem {
class LayoutTransformContext {
public:
    using OprList = SubGraphExtractor::OprList;
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    using OprTensorFormatsDispatcher =
            OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
    using OprConfigTrait =
            ThinHashMap<Typeinfo*,
                        ThinHashMap<OprFormat, OprTensorFormatsDispatcher*>>;
    using ReformatAttribute = ReformatManager::ReformatKey::Attribute;
    struct Attribute {
        OprFormat base_opr_format;  /// the base opr format indicates that the
                                    /// network to be optimized is constructed
@@ -62,58 +65,110 @@ public:
                                    /// (like elemwise, elemwise multi type,
                                    /// typecvt etc.) are built in the base
                                    /// tensor format.
        ReformatAttribute
                reformat_attribute;  /// additional reformat attribute, which
                                     /// indicates whether to pad nhwc layout
                                     /// automatically or to enable nhwcd4
                                     /// format on opencl platform to use
                                     /// image object
    };
    Problem(const GraphPartition& graph_partition,
            const SmallVector<TensorFormats>& available_tensor_formats,
            const OprConfigTrait& opr_config, const Attribute& attribute)
            : m_graph_partition{graph_partition},
              m_available_tensor_formats{available_tensor_formats},
              m_opr_configs{opr_config},
    LayoutTransformContext() = delete;
    LayoutTransformContext(OprList opr_list,
                           SmallVector<TensorFormats> available_tensor_formats,
                           Attribute attribute)
            : m_opr_list{std::move(opr_list)},
              m_available_tensor_formats{std::move(available_tensor_formats)},
              m_attribute{attribute} {}
    LayoutTransformContext(OprList opr_list,
                           SmallVector<TensorFormats> available_tensor_formats,
                           OprConfigTrait opr_configs, Attribute attribute)
            : m_opr_list{std::move(opr_list)},
              m_available_tensor_formats{std::move(available_tensor_formats)},
              m_opr_configs{std::move(opr_configs)},
              m_attribute{attribute} {}
    const OprList& opr_list() const { return m_opr_list; }
    const SmallVector<TensorFormats>& available_tensor_formats() const {
        return m_available_tensor_formats;
    }
    const OprConfigTrait& opr_configs() const { return m_opr_configs; }
    Attribute attribute() const { return m_attribute; }
    /*!
     * \brief add an op format configuration for a particular operator type
     * \param opr runtime typeinfo of operator
     * \param opr_format op format configuration which is to be enabled in the
     * layout transform problem
     */
    LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormat opr_format);
    /*!
     * \brief add a vector of op format configurations for a particular
     * operator type
     * \param opr runtime typeinfo of operator
     * \param opr_formats op format configurations which are to be enabled in
     * the layout transform problem
     */
    LayoutTransformContext& add_opr_config(Typeinfo* opr,
                                           SmallVector<OprFormat> opr_formats);
private:
    OprList m_opr_list;  /// supported operator list
    SmallVector<TensorFormats>
            m_available_tensor_formats;  /// the available tensor formats, used
                                         /// for format agnostic operators
                                         /// (like elemwise, elemwise multi
                                         /// type, typecvt, etc.)
    OprConfigTrait m_opr_configs;  /// the available opr format configurations,
                                   /// used for format aware operators (like
                                   /// conv, deconv, conv_bias, etc.)
    Attribute m_attribute;  /// the extra attributes to describe the problem
};
class Problem {
public:
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    using OprTensorFormatsDispatcher =
            OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
    using OprConfigTrait = LayoutTransformContext::OprConfigTrait;
    using Attribute = LayoutTransformContext::Attribute;
    Problem(const GraphPartition& graph_partition,
            const LayoutTransformContext& ctx)
            : m_graph_partition{graph_partition}, m_ctx{ctx} {}
    ~Problem() noexcept = default;
    const GraphPartition& graph_partition() const { return m_graph_partition; }
    const OprConfigTrait& opr_configs() const { return m_opr_configs; }
    const OprConfigTrait& opr_configs() const { return m_ctx.opr_configs(); }
    const SmallVector<TensorFormats>& available_tensor_formats() const {
        return m_available_tensor_formats;
        return m_ctx.available_tensor_formats();
    }
    TensorFormats base_format() const {
        return m_attribute.base_tensor_formats;
        return m_ctx.attribute().base_tensor_formats;
    }
    /*!
     * \brief return the tensor formats configuration of an operator in the
     * default op format
     */
    OprTensorFormatsConfiguration base_config(
            const cg::OperatorNodeBase* opr) const {
        auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
                opr->dyn_typeinfo(), m_attribute.base_opr_format);
                opr->dyn_typeinfo(), m_ctx.attribute().base_opr_format);
        auto rst = (*_)(opr);
        if (rst.valid())
            return rst.val();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = m_attribute.base_opr_format;
        config.opr_format = m_ctx.attribute().base_opr_format;
        for (const auto& i : opr->input()) {
            config.input_dtypes.emplace_back(i->dtype().enumv());
            config.input_tensor_formats.emplace_back(
                    m_attribute.base_tensor_formats);
            config.input_tensor_formats.emplace_back(base_format());
            config.input_tensor_types.emplace_back(TensorType::FEATURE);
        }
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        config.output_tensor_formats.emplace_back(
                m_attribute.base_tensor_formats);
        config.output_tensor_formats.emplace_back(base_format());
        return config;
    }

private:
    const GraphPartition& m_graph_partition;  /// the graph partition
    const SmallVector<TensorFormats>&
            m_available_tensor_formats;  /// the available tensor formats, used
                                         /// for format agnostic operators (like
                                         /// elemwise, elemwise multi type,
                                         /// typecvt, etc.
    const OprConfigTrait&
            m_opr_configs;  /// the available opr format configurations, used
                            /// for format aware operators (like conv, deconv,
                            /// conv_bias, etc.
    Attribute m_attribute;  /// the extra attributes to describe the problem
    const LayoutTransformContext& m_ctx;
};
/*!
@@ -170,6 +225,92 @@ public:
    static std::unique_ptr<ProfilerBase> make_profiler();
};
/*!
 * \brief abstract solver
 */
class SolverBase {
public:
    using OprFormat = Problem::OprFormat;
    using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>;
    SolverBase() = default;
    virtual ~SolverBase() = default;
    /*!
     * \brief solve the given problem
     */
    virtual Solution solve(const Problem& problem) const = 0;
    /*!
     * \brief check whether the given problem can be solved by this
     * algorithm (i.e. solver).
     */
    virtual bool can_solve(const Problem& problem) const = 0;
};
/*!
 * \brief a solver that first collects the costs of operators in different op
 * formats and the costs of varnode layout transforms, using a user-provided
 * profiler on the target device. The profiling step makes this kind of solver
 * time-consuming.
 */
class ProfilingBasedSolver : public SolverBase {
public:
    using GraphPartitionFilter =
            thin_function<bool(const GraphPartition& graph_partition)>;
    ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler);
    /*!
     * \note some graph partitions (for example, those without format aware
     * operators like conv, deconv, warp, resize etc.) will be filtered out by
     * the GraphPartitionFilter, which can reduce the profiling time. */
    ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler,
                         GraphPartitionFilter graph_partition_filter)
            : m_profiler{std::move(profiler)},
              m_graph_partition_filter{std::move(graph_partition_filter)} {}
    virtual ~ProfilingBasedSolver() = default;
    Solution solve(const Problem& problem) const override;
    virtual Solution do_solve(const Problem& problem) const = 0;

protected:
    std::unique_ptr<ProfilerBase> m_profiler;

private:
    GraphPartitionFilter m_graph_partition_filter;
};
/*!
 * \brief a solver that solves the layout selection problem using a dynamic
 * programming algorithm (Markov decision process).
 */
class DynamicProgrammingSolver final : public ProfilingBasedSolver {
public:
    DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler)
            : ProfilingBasedSolver(std::move(profiler)) {}
    DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler,
                             GraphPartitionFilter graph_partition_filter)
            : ProfilingBasedSolver(std::move(profiler),
                                   std::move(graph_partition_filter)) {}
    ~DynamicProgrammingSolver() noexcept = default;
    Solution do_solve(const Problem& problem) const override;
    bool can_solve(const Problem& problem) const override;

private:
    class Impl;
};
/*!
 * \brief a layout transform pass, which converts the formats of the operators
 * to the optimal formats using the results of the solver.
 */
class LayoutTransformPass final : public Pass {
public:
    const char* name() const override { return "layout assignment pass"; }
    void apply(OptState& opt) const override;
    LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx,
                        std::unique_ptr<SolverBase> solver)
            : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {}

private:
    std::unique_ptr<LayoutTransformContext> m_ctx;
    std::unique_ptr<SolverBase> m_solver;
};
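
/*
 * A sketch of wiring the pieces in this header together (construction of
 * the context is omitted; see LayoutTransformContext::add_opr_config):
 *
 *     std::unique_ptr<LayoutTransformContext> ctx = ...;
 *     auto solver = std::make_unique<DynamicProgrammingSolver>(
 *             ProfilerBase::make_profiler());
 *     auto pass = std::make_unique<LayoutTransformPass>(std::move(ctx),
 *                                                       std::move(solver));
 */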
}  // namespace gopt
}  // namespace mgb
@@ -84,7 +84,7 @@ public:
                  output_dtype{DTypeEnum::Float32},
                  attribute{Attribute::DEFAULT} {}
        ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
                    Attribute attribute_ = Attribute::DEFAULT,
                    Attribute attribute_,
                    DTypeEnum input_dtype_ = DTypeEnum::Float32,
                    DTypeEnum output_dtype_ = DTypeEnum::Float32)
                : input_format{input_format_},
@@ -92,6 +92,15 @@ public:
                  input_dtype{input_dtype_},
                  output_dtype{output_dtype_},
                  attribute{attribute_} {}
        ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
                    DTypeEnum input_dtype_ = DTypeEnum::Float32,
                    DTypeEnum output_dtype_ = DTypeEnum::Float32,
                    Attribute attribute_ = Attribute::DEFAULT)
                : input_format{input_format_},
                  output_format{output_format_},
                  input_dtype{input_dtype_},
                  output_dtype{output_dtype_},
                  attribute{attribute_} {}
        struct Hash {
            size_t operator()(const ReformatKey& key) const;
        };
@@ -99,7 +108,6 @@ public:
            bool operator()(const ReformatKey& lhs,
                            const ReformatKey& rhs) const;
        };
        ReformatKey& deduce_reformat_dtype_enum(const DType& dt);
    };
    using ReformatCache =
            std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash,
@@ -130,6 +138,9 @@ TensorShape make_aligned_weight_shape(const VarNode* var,
                                      TensorFormats orig_formats,
                                      TensorFormats target_formats,
                                      TensorFormats extra_formats);

ReformatManager::AlignmentDesc make_aligned_desc(
        TensorFormats weight_format, TensorFormats out_feature_format);

}  // namespace gopt
}  // namespace mgb
@@ -20,6 +20,7 @@ class GraphPartition {
public:
    using VarNodeSet = ThinHashSet<VarNode*>;
    using OperatorNodeSet = ThinHashSet<cg::OperatorNodeBase*>;
    using OperatorNodeList = std::vector<cg::OperatorNodeBase*>;

    class InputPlaceholder;
@@ -32,15 +33,18 @@ public:
    const OperatorNodeSet& opr_set() const { return m_opr_set; }
    const VarNodeSet& input() const { return m_inputs; }
    const VarNodeSet& output() const { return m_outputs; }
    const OperatorNodeList& all_oprs() const { return m_oprs; }
    OperatorNodeSet& opr_set() { return m_opr_set; }
    OperatorNodeList& all_oprs() { return m_oprs; }
    VarNodeSet& input() { return m_inputs; }
    VarNodeSet& output() { return m_outputs; }

private:
    std::pair<VarNodeArray, VarNodeArray> replace_graph_by_placeholder() const;
    OperatorNodeSet m_opr_set;
    OperatorNodeList m_oprs;
    VarNodeSet m_inputs;
    VarNodeSet m_outputs;
    std::pair<VarNodeArray, VarNodeArray> replace_graph_by_placeholder() const;
};

class SubGraphExtractor {
@@ -10,6 +10,7 @@
 */
#include "megbrain/plugin/profiler.h"
#include "./helper.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/inference.h"
@@ -22,123 +23,59 @@ using namespace mgb;
using namespace gopt;
using namespace serialization;

#if MGB_CUDA
namespace {
class LayoutTransformContext : public NonCopyableObj {
public:
    using OprList = SubGraphExtractor::OprList;
    using OprFormat = Problem::OprFormat;
    using OprConfigTrait = Problem::OprConfigTrait;
    LayoutTransformContext() = delete;
    LayoutTransformContext(OprList opr_list,
                           SmallVector<TensorFormats> available_tensor_formats,
                           OprConfigTrait opr_configs)
            : m_opr_list{std::move(opr_list)},
              m_available_tensor_formats{std::move(available_tensor_formats)},
              m_opr_configs{std::move(opr_configs)} {}
    const OprList& opr_list() const { return m_opr_list; }
    const SmallVector<TensorFormats>& available_tensor_formats() const {
        return m_available_tensor_formats;
    }
    const OprConfigTrait& opr_configs() const { return m_opr_configs; }
    static std::unique_ptr<LayoutTransformContext> make() {
        OprList opr_list = {
                opr::ConvBiasForward::typeinfo(),
                opr::ConvolutionForward::typeinfo(),
                opr::ConvolutionBackwardData::typeinfo(),
                opr::ElemwiseMultiType::typeinfo(),
                opr::Elemwise::typeinfo(),
                opr::TypeCvt::typeinfo(),
                opr::PoolingForward::typeinfo(),
                opr::WarpPerspectiveForward::typeinfo(),
        };
        OprConfigTrait opr_configs;
        {
            auto& dispatchers = opr_configs[opr::ConvBias::typeinfo()];
#define cb(_fmt)                                                           \
    dispatchers[OprFormat::_fmt] =                                         \
            OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \
                    opr::ConvBias::typeinfo(), OprFormat::_fmt);
            cb(NCHW4);
            cb(NCHW32);
            cb(NHWC);
            cb(NCHW64);
            cb(CHWN4);
#undef cb
        }
        {
            auto& dispatchers =
                    opr_configs[opr::ConvolutionBackwardData::typeinfo()];
#define cb(_fmt)                                                           \
    dispatchers[OprFormat::_fmt] =                                         \
            OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \
                    opr::ConvolutionBackwardData::typeinfo(),              \
                    OprFormat::_fmt);
            cb(NCHW4);
#undef cb
        }
        {
            auto& dispatchers =
                    opr_configs[opr::ConvolutionForward::typeinfo()];
#define cb(_fmt)                                                           \
    dispatchers[OprFormat::_fmt] =                                         \
            OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \
                    opr::ConvolutionForward::typeinfo(), OprFormat::_fmt);
            cb(NCHW4);
#undef cb
        }
        {
            auto& dispatchers = opr_configs[opr::PoolingForward::typeinfo()];
#define cb(_fmt)                                                           \
    dispatchers[OprFormat::_fmt] =                                         \
            OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \
                    opr::PoolingForward::typeinfo(), OprFormat::_fmt);
            cb(NCHW4);
            cb(NCHW32);
            cb(NHWC);
            cb(NCHW64);
            cb(CHWN4);
#undef cb
        }
        {
            auto& dispatchers =
                    opr_configs[opr::WarpPerspectiveForward::typeinfo()];
#define cb(_fmt)                                                           \
    dispatchers[OprFormat::_fmt] =                                         \
            OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \
                    opr::WarpPerspectiveForward::typeinfo(), OprFormat::_fmt);
            cb(NHWC);
            cb(NCHW4);
            cb(NCHW64);
#undef cb
        }
        SmallVector<TensorFormats> available_tensor_formats = {
                TensorFormats::NHWC, TensorFormats::NCHWc4,
                TensorFormats::NCHWc32, TensorFormats::NCHWc64};
        return std::make_unique<LayoutTransformContext>(
                std::move(opr_list), std::move(available_tensor_formats),
                std::move(opr_configs));
    }

private:
    OprList m_opr_list;
    SmallVector<TensorFormats> m_available_tensor_formats;
    OprConfigTrait m_opr_configs;
};
}  // namespace

std::unique_ptr<LayoutTransformContext> make_ctx() {
    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    using Attribute = LayoutTransformContext::Attribute;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionForward::typeinfo(),
            opr::ConvolutionBackwardData::typeinfo(),
            opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),
            opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),
            opr::WarpPerspectiveForward::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW,    TensorFormats::NHWC,
            TensorFormats::NCHWc4,  TensorFormats::NCHWc32,
            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
                           ReformatAttribute::DEFAULT};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats),
            attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4,
                OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
            .add_opr_config(opr::ConvolutionForward::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW4})
            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW4})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
                     OprFormat::NCHW64, OprFormat::CHWN4})
            .add_opr_config(
                    opr::WarpPerspectiveForward::typeinfo(),
                    {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
    return ctx;
}
}  // namespace

#if MGB_CUDA
#if CUDA_VERSION >= 10020
TEST(TestProfiler, Conv) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
    auto ctx = LayoutTransformContext::make();
    auto ctx = make_ctx();

    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -177,14 +114,10 @@ TEST(TestProfiler, Conv) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({c2}, strategy);
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({c2});
    ASSERT_EQ(partitions.size(), 1u);
    using Attribute = Problem::Attribute;
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
    Problem problem(partitions[0], ctx->available_tensor_formats(),
                    ctx->opr_configs(), attribute);
    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -204,7 +137,7 @@ TEST(TestProfiler, Deconv) {
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
    auto ctx = LayoutTransformContext::make();
    auto ctx = make_ctx();

    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -238,14 +171,10 @@ TEST(TestProfiler, Deconv) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({c2}, strategy);
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({c2});
    ASSERT_EQ(partitions.size(), 1u);
    using Attribute = Problem::Attribute;
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
    Problem problem(partitions[0], ctx->available_tensor_formats(),
                    ctx->opr_configs(), attribute);
    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -262,7 +191,7 @@ TEST(TestProfiler, Warp) {
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
    auto ctx = LayoutTransformContext::make();
    auto ctx = make_ctx();

    constexpr size_t INP_H = 10, INP_W = 10, N = 16;
@@ -307,14 +236,9 @@ TEST(TestProfiler, Warp) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({w1}, strategy);
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({w1});
    ASSERT_EQ(partitions.size(), 1u);
    using Attribute = Problem::Attribute;
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
    Problem problem(partitions[0], ctx->available_tensor_formats(),
                    ctx->opr_configs(), attribute);
    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -330,7 +254,7 @@ TEST(TestProfiler, Pooling) {
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
    auto ctx = LayoutTransformContext::make();
    auto ctx = make_ctx();

    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -353,14 +277,10 @@ TEST(TestProfiler, Pooling) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({p2}, strategy);
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({p2});
    ASSERT_EQ(partitions.size(), 1u);
    using Attribute = Problem::Attribute;
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
    Problem problem(partitions[0], ctx->available_tensor_formats(),
                    ctx->opr_configs(), attribute);
    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -373,8 +293,7 @@ TEST(TestProfiler, Elemwise) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
    auto ctx = LayoutTransformContext::make();
    auto ctx = make_ctx();

    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -403,14 +322,10 @@ TEST(TestProfiler, Elemwise) {
            OperatorNodeConfig(
                    dtype::Quantized4Asymm(13.f, static_cast<uint8_t>(4))));
    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({q4e});
    ASSERT_EQ(partitions.size(), 1u);
    using Attribute = Problem::Attribute;
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
    Problem problem(partitions[0], ctx->available_tensor_formats(),
                    ctx->opr_configs(), attribute);
    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -423,7 +338,6 @@ TEST(TestProfiler, Elemwise) {
    EXPECT_TRUE(var_rst.count(q8a.node()) > 0);
    EXPECT_TRUE(var_rst.count(q8b.node()) > 0);
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -447,6 +447,7 @@ TEST(TestReformatManager, AutoAlignedFeatureProfiling) {
    for (size_t i = 0; i < RUNS; ++i)
        func->execute();
    double time_profiler = profiler->duration() * 1e6;
    printf("time: %f, %f\n", time_cuda_evt, time_profiler);
    MGB_CUDA_CHECK(cudaEventDestroy(evt0));
    MGB_CUDA_CHECK(cudaEventDestroy(evt1));
}