GitOrigin-RevId: 595392ec89
release-1.6
@@ -0,0 +1,547 @@
/**
 * \file src/gopt/impl/dynamic_programming_solver.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include <list>
#include <queue>
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"

using namespace mgb;
using namespace gopt;
using namespace cg;

/* ================= DynamicProgrammingSolver::Impl ================== */
class DynamicProgrammingSolver::Impl {
public:
    Impl(size_t max_states) : m_max_states{max_states} {}
    ~Impl() = default;

    Solution solve(const ProfilerBase* profiler, const Problem& problem);

private:
    using TensorFormatsBitSet = uint32_t;
    using State = SmallVector<TensorFormatsBitSet>;
    /// one bit per tensor format, so the capacity is the bit width of the set
    static constexpr uint32_t MAX_TENSOR_FORMATS =
            sizeof(TensorFormatsBitSet) * 8;
    TensorFormatsBitSet add(TensorFormatsBitSet& set, TensorFormats fmt) {
        mgb_assert(static_cast<uint32_t>(fmt) < MAX_TENSOR_FORMATS);
        set |= (1 << static_cast<uint32_t>(fmt));
        return set;
    }
    bool valid(const TensorFormatsBitSet& set, TensorFormats fmt) {
        mgb_assert(static_cast<uint32_t>(fmt) < MAX_TENSOR_FORMATS);
        bool val = set & (1 << static_cast<uint32_t>(fmt));
        return val;
    }
    struct Value {
        OperatorNodeBase* opr;
        const State* prev;
        OprFormat opr_fmt;
        float time;
        ///! index of the corresponding operator in the topological order
        size_t opr_idx;
    };
    struct StateHash {
        size_t operator()(const State& key) const {
            size_t h = 0;
            for (auto&& v : key) {
                h = mgb::hash_pair_combine(
                        h, std::hash<TensorFormatsBitSet>{}(v));
            }
            return h;
        }
    };
    struct StateEqual {
        bool operator()(const State& lhs, const State& rhs) const {
            if (lhs.size() != rhs.size())
                return false;
            for (size_t i = 0; i < lhs.size(); ++i) {
                if (lhs[i] != rhs[i])
                    return false;
            }
            return true;
        }
    };
    using StateTable = std::unordered_map<State, Value, StateHash, StateEqual>;
    struct Cut {
        StateTable states;
    };
    using ProfilingResult = ProfilerBase::ProfilingResult;
    using OprConfigTrait = LayoutTransformContext::OprConfigTrait;
    struct Context {
        const std::vector<OperatorNodeBase*>& topo;
        const ProfilingResult& rst;
        const OprConfigTrait& opr_configs;
        const SmallVector<TensorFormats>& available_tensor_formats;
    };
    /*!
     * \brief get the tensor formats configuration of the operator with a
     * particular op format
     * \param[out] var2fmts hashmap that maps each input varnode to the actual
     * tensor format required by the op format configuration
     * \param[in] opr given operator
     * \param[in] opr_fmt given op format, an enum value which indicates the
     * op format configuration
     * \param[in] ctx context
     */
    TensorFormats get_io_formats(ThinHashMap<VarNode*, TensorFormats>& var2fmts,
                                 const OperatorNodeBase* opr, OprFormat opr_fmt,
                                 const Context& ctx);
    /*!
     * \brief compute the distance between two states of the given varnode
     * \param[in] from the source state
     * \param[in] to the target state
     * \param[in] var given varnode
     * \param[in] ctx context
     */
    float distance(const TensorFormatsBitSet& from,
                   const TensorFormatsBitSet& to, VarNode* var,
                   const Context& ctx);
    /*!
     * \brief compute the distance between two states of the given cut edges
     * \param[in] from the source state
     * \param[in] to the target state
     * \param[in] edge a VarNodeArray, the given cut edges
     * \param[in] ctx context
     */
    float state_distance(const State& from, const State& to,
                         const VarNodeArray& edge, const Context& ctx);
    /*!
     * \brief analyze the edges of each cut
     * \param[out] edges the returned edges of the cuts
     * \param[out] edge2idx hashmaps that map each edge (varnode) to its index
     * \param[in] ctx context
     */
    void analyze_edges(SmallVector<VarNodeArray>& edges,
                       SmallVector<std::unordered_map<VarNode*, int>>& edge2idx,
                       const Context& ctx);
    /*!
     * \brief prune dominated states using the distance between states
     */
    void prune(StateTable& states, const VarNodeArray& edge,
               const Context& ctx);
    /*!
     * \brief force prune the states, reserving only the m_max_states states
     * with the smallest times
     */
    void force_prune(StateTable& states);

private:
    size_t m_max_states;
};
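
// ---------------------------------------------------------------------------
// Editorial sketch (not part of this patch): the bit-set bookkeeping used by
// Impl::add()/Impl::valid() above, written out as a self-contained check.
// `Fmt` is a hypothetical stand-in enum; the only assumption is that each
// TensorFormats enumerator is a small integer, which the asserts above
// enforce (hence MAX_TENSOR_FORMATS = sizeof(TensorFormatsBitSet) * 8 = 32).
namespace bitset_example {
enum class Fmt : uint32_t { NCHW = 0, NHWC = 1, NCHWc4 = 2 };
inline bool example() {
    uint32_t set = 0;                                // empty bit set
    set |= 1u << static_cast<uint32_t>(Fmt::NCHW);   // add(set, Fmt::NCHW)
    set |= 1u << static_cast<uint32_t>(Fmt::NHWC);   // add(set, Fmt::NHWC)
    bool has_nchw = set & (1u << static_cast<uint32_t>(Fmt::NCHW));    // true
    bool has_c4 = set & (1u << static_cast<uint32_t>(Fmt::NCHWc4));    // false
    return has_nchw && !has_c4;                      // true
}
}  // namespace bitset_example
// ---------------------------------------------------------------------------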
TensorFormats DynamicProgrammingSolver::Impl::get_io_formats(
        ThinHashMap<VarNode*, TensorFormats>& var2fmts,
        const OperatorNodeBase* opr, OprFormat opr_fmt, const Context& ctx) {
    auto&& rst = ctx.rst;
    auto&& opr_configs = ctx.opr_configs;
    auto iter = opr_configs.find(opr->dyn_typeinfo());
    Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
    if (iter != opr_configs.end()) {
        fmtcfg = (*iter->second.at(opr_fmt))(opr);
    }
    TensorFormats out_fmt;
    if (fmtcfg.valid())
        out_fmt = fmtcfg.val().output_tensor_formats[0];
    else
        out_fmt = opr_format_to_tensor_formats(opr_fmt);
    for (size_t i = 0; i < opr->input().size(); ++i) {
        auto&& var = opr->input(i);
        auto iter = rst.var_record.find(var);
        if (iter != rst.var_record.end()) {
            if (fmtcfg.valid())
                var2fmts[var] = fmtcfg.val().input_tensor_formats[i];
            else
                var2fmts[var] = opr_format_to_tensor_formats(opr_fmt);
        }
    }
    return out_fmt;
}
float DynamicProgrammingSolver::Impl::distance(const TensorFormatsBitSet& from,
                                               const TensorFormatsBitSet& to,
                                               VarNode* var,
                                               const Context& ctx) {
    auto&& costs = ctx.rst.var_record.at(var).costs;
    auto&& available_tensor_formats = ctx.available_tensor_formats;
    float dist = 0.f;
    if ((from & to) == to)
        return dist;
    auto to_set = ((from | to) ^ from);
    for (auto o : available_tensor_formats) {
        if (valid(to_set, o)) {
            float o_cost = std::numeric_limits<float>::max();
            for (auto i : available_tensor_formats) {
                if (valid(from, i)) {
                    float cost = costs.at({i, o});
                    o_cost = std::min(o_cost, cost);
                }
            }
            dist += o_cost;
        }
    }
    return dist;
}
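
// Editorial note (not part of this patch): distance() implements
//
//     dist(from, to) = sum over o in (to \ from) of
//                      min over i in from of cost(i -> o)
//
// i.e. every format that `to` requires but `from` lacks is charged the
// cheapest conversion from any format already materialized in `from`.
// For example, from = {NCHW} and to = {NCHW, NHWC} adds only cost(NCHW->NHWC).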
float DynamicProgrammingSolver::Impl::state_distance(const State& from,
                                                     const State& to,
                                                     const VarNodeArray& edge,
                                                     const Context& ctx) {
    float dist = 0.f;
    mgb_assert(from.size() == to.size() && from.size() == edge.size());
    for (size_t i = 0; i < edge.size(); ++i) {
        dist += distance(from[i], to[i], edge[i], ctx);
    }
    return dist;
}

void DynamicProgrammingSolver::Impl::analyze_edges(
        SmallVector<VarNodeArray>& edges,
        SmallVector<std::unordered_map<VarNode*, int>>& edge2idx,
        const Context& ctx) {
    auto&& topo = ctx.topo;
    auto&& rst = ctx.rst;
    size_t nr_oprs = topo.size();
    edges.resize(nr_oprs);
    edge2idx.resize(nr_oprs);
    ThinHashSet<VarNode*> cur_edge;
    size_t cur = nr_oprs - 1;
    int idx = 0;
    for (auto&& ov : topo[cur]->usable_output()) {
        edges[cur].push_back(ov);
        edge2idx[cur].emplace(ov, idx++);
    }
    cur--;
    for (const auto& opr : reverse_adaptor(topo)) {
        for (const auto& i : opr->input()) {
            if (rst.var_record.count(i) > 0) {
                cur_edge.insert(i);
            }
        }
        for (auto&& ov : opr->usable_output()) {
            cur_edge.erase(ov);
        }
        edges[cur].insert(edges[cur].begin(), cur_edge.begin(), cur_edge.end());
        int i = 0;
        for (auto&& e : edges[cur]) {
            edge2idx[cur][e] = i++;
        }
        if (cur == 0)
            break;
        cur--;
    }
}
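
// Editorial note (not part of this patch): after analyze_edges(), edges[i]
// holds the varnodes that are still live after the i-th operator in
// topological order, i.e. the cut between operator i and its successors.
// For a chain a -> b -> c, the cut after b is exactly the outputs of b that
// c (or the partition output) still consumes.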
void DynamicProgrammingSolver::Impl::prune(StateTable& states,
                                           const VarNodeArray& edge,
                                           const Context& ctx) {
    struct Item {
        decltype(states.begin()) iter;
    };
    std::list<Item> list;
    for (auto it = states.begin(); it != states.end(); ++it) {
        list.emplace_back(Item{it});
    }
    SmallVector<State> removed_states;
    for (auto i = list.begin(); i != list.end();) {
        bool advanced_i = false;
        for (auto j = std::next(i, 1); j != list.end();) {
            if (i->iter->second.time > j->iter->second.time &&
                state_distance(j->iter->first, i->iter->first, edge, ctx) <
                        i->iter->second.time - j->iter->second.time) {
                removed_states.push_back(i->iter->first);
                i = list.erase(i);
                advanced_i = true;
                break;
            } else if (i->iter->second.time < j->iter->second.time &&
                       state_distance(i->iter->first, j->iter->first, edge,
                                      ctx) <
                               j->iter->second.time - i->iter->second.time) {
                removed_states.push_back(j->iter->first);
                j = list.erase(j);
            } else {
                j = std::next(j, 1);
            }
        }
        if (!advanced_i)
            i = std::next(i, 1);
    }
    for (auto&& state : removed_states)
        states.erase(state);
}
void DynamicProgrammingSolver::Impl::force_prune(StateTable& states) {
    if (states.size() < m_max_states)
        return;
    struct Item {
        decltype(states.begin()) iter;
    };
    auto cmp = [](Item lhs, Item rhs) {
        return lhs.iter->second.time < rhs.iter->second.time;
    };
    std::priority_queue<Item, std::vector<Item>, decltype(cmp)> pq(cmp);
    for (auto it = states.begin(); it != states.end(); ++it) {
        if (pq.size() < m_max_states)
            pq.push(Item{it});
        else {
            auto i = pq.top();
            if (it->second.time < i.iter->second.time) {
                pq.pop();
                pq.push(Item{it});
            }
        }
    }
    StateTable active_state;
    while (!pq.empty()) {
        auto i = pq.top();
        active_state.insert(*i.iter);
        pq.pop();
    }
    states.swap(active_state);
}
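
// Editorial note (not part of this patch): force_prune() keeps the
// m_max_states cheapest states with a bounded max-heap: pq.top() is the most
// expensive retained state, so a new state replaces it only when strictly
// cheaper. This caps the per-cut state count (and thus the DP's running
// time) at MAX_STATES regardless of how many states survive prune().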
DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
        const ProfilerBase* profiler, const Problem& problem) {
    const auto rst = profiler->profile(problem);
    const auto& partition = problem.graph_partition();
    const auto& opr_configs = problem.opr_configs();
    const auto& base_fmt = problem.base_format();
    const auto& available_tensor_formats = problem.available_tensor_formats();
    const auto& topo = partition.all_oprs();
    Context ctx{topo, rst, opr_configs, available_tensor_formats};

    SmallVector<VarNodeArray> edges;
    SmallVector<std::unordered_map<VarNode*, int>> edge2idx;
    /// analyze the edges of each cut
    analyze_edges(edges, edge2idx, ctx);

    SmallVector<Cut> cuts;
    size_t cur = 0;
    /// initialize the states of the first cut
    auto init = [&, this](OperatorNodeBase* opr) {
        auto it = rst.opr_record.find(opr);
        if (it == rst.opr_record.end())
            return;
        ThinHashSet<VarNode*> ovar_set;
        for (auto&& ov : opr->usable_output()) {
            ovar_set.insert(ov);
        }
        const auto& records = it->second.costs;
        cuts.emplace_back(Cut{});
        auto& states = cuts.back().states;
        for (const auto& record : records) {
            auto opr_fmt = record.first;
            float opr_time = record.second;
            ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
            auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
            const auto& edge = edges[cur];
            State state(edge.size(), 0);
            Value value{opr, nullptr, opr_fmt, 0.f, cur};
            float ovar_time = 0.f;
            for (size_t i = 0; i < edge.size(); ++i) {
                auto&& var = edge[i];
                auto&& costs = rst.var_record.at(var).costs;
                if (ovar_set.count(var) > 0) {
                    add(state[i], out_fmt);
                    if (partition.output().count(var) > 0 &&
                        out_fmt != base_fmt) {
                        ovar_time += costs.at({out_fmt, base_fmt});
                        add(state[i], base_fmt);
                    }
                } else {
                    add(state[i], base_fmt);
                }
            }
            float ivar_time = 0.f;
            for (const auto& kv : ivar2fmts) {
                auto&& v = kv.first;
                auto&& costs = rst.var_record.at(v).costs;
                auto to = kv.second;
                float min_time = std::numeric_limits<float>::max();
                if (base_fmt == to) {
                    min_time = 0.f;
                } else {
                    min_time = costs.at({base_fmt, to});
                    if (edge2idx[cur].count(v) > 0) {
                        add(state[edge2idx[cur][v]], to);
                    }
                }
                ivar_time += min_time;
            }
            value.time = opr_time + ivar_time + ovar_time;
            states[state] = value;
        }
    };
    /// update the states of the subsequent cuts
    auto body = [&, this](OperatorNodeBase* opr) {
        auto it = rst.opr_record.find(opr);
        if (it == rst.opr_record.end())
            return;
        ThinHashSet<VarNode*> ovar_set;
        for (auto&& ov : opr->usable_output()) {
            ovar_set.insert(ov);
        }
        const auto& records = it->second.costs;
        StateTable states;
        for (const auto& record : records) {
            auto opr_fmt = record.first;
            float opr_time = record.second;
            ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
            auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
            for (const auto& kv : cuts.back().states) {
                auto&& prev_state = kv.first;
                float prev_time = kv.second.time;
                const auto& edge = edges[cur];
                State state(edge.size(), 0);
                Value value{opr, &prev_state, opr_fmt, 0.f, cur};
                float ovar_time = 0.f;
                for (size_t i = 0; i < edge.size(); ++i) {
                    auto&& var = edge[i];
                    auto&& costs = rst.var_record.at(var).costs;
                    auto iter = edge2idx[cur - 1].find(var);
                    if (iter != edge2idx[cur - 1].end()) {
                        state[i] = prev_state[iter->second];
                    } else {
                        mgb_assert(ovar_set.count(var) > 0);
                        add(state[i], out_fmt);
                        if (partition.output().count(var) > 0 &&
                            out_fmt != base_fmt) {
                            ovar_time += costs.at({out_fmt, base_fmt});
                            add(state[i], base_fmt);
                        }
                    }
                }
                float ivar_time = 0.f;
                for (const auto& kv : ivar2fmts) {
                    auto&& v = kv.first;
                    auto&& costs = rst.var_record.at(v).costs;
                    auto to = kv.second;
                    auto it1 = edge2idx[cur - 1].find(v);
                    mgb_assert(it1 != edge2idx[cur - 1].end());
                    float min_time = std::numeric_limits<float>::max();
                    if (valid(prev_state[it1->second], to)) {
                        min_time = 0.f;
                    } else {
                        for (auto&& from : available_tensor_formats) {
                            if (valid(prev_state[it1->second], from)) {
                                float cost = costs.at({from, to});
                                min_time = std::min(min_time, cost);
                            }
                        }
                    }
                    auto it2 = edge2idx[cur].find(v);
                    if (it2 != edge2idx[cur].end()) {
                        add(state[it2->second], to);
                    }
                    ivar_time += min_time;
                }
                value.time = prev_time + opr_time + ivar_time + ovar_time;
                auto iter = states.find(state);
                if (iter == states.end()) {
                    states[state] = value;
                } else {
                    float time = iter->second.time;
                    if (value.time < time) {
                        iter->second = value;
                    }
                }
            }
        }
        cuts.emplace_back(Cut{});
        cuts.back().states.swap(states);
    };

    /// forward pass to generate all states
    for (auto&& opr : topo) {
        if (cuts.empty()) {
            init(opr);
        } else {
            body(opr);
        }
        if (!cuts.empty()) {
            auto& states = cuts.back().states;
            prune(states, edges[cur], ctx);
            force_prune(states);
        }
        cur++;
    }
    Solution solution;
    /// backward pass to generate the solution
    float min_time = std::numeric_limits<float>::max();
    OperatorNodeBase* cur_opr = nullptr;
    OprFormat min_fmt = OprFormat::NCHW;
    const State* pstate = nullptr;
    for (auto&& kv : cuts.back().states) {
        auto&& v = kv.second;
        if (v.time < min_time) {
            cur_opr = v.opr;
            pstate = v.prev;
            min_time = v.time;
            min_fmt = v.opr_fmt;
            ///! sanity check: the partition outputs must be available in the
            ///! base tensor format
            auto&& k = kv.first;
            size_t opr_idx = v.opr_idx;
            for (size_t i = 0; i < k.size(); ++i) {
                auto&& fmt_set = k[i];
                auto&& var = edges[opr_idx][i];
                if (partition.output().count(var)) {
                    mgb_assert(valid(fmt_set, base_fmt));
                }
            }
        }
    }
    mgb_assert(cur_opr != nullptr);
    mgb_log_debug("opr:%s;format:%s;time:%f", cur_opr->cname(),
                  opr_format_to_string(min_fmt), min_time);
    solution.insert({cur_opr, min_fmt});
    cur = cuts.size() - 2;
    while (pstate) {
        auto val = cuts[cur].states[*pstate];
        ///! sanity check: the partition outputs must be available in the base
        ///! tensor format
        size_t opr_idx = val.opr_idx;
        for (size_t i = 0; i < pstate->size(); ++i) {
            auto&& fmt_set = pstate->operator[](i);
            auto&& var = edges[opr_idx][i];
            if (partition.output().count(var)) {
                mgb_assert(valid(fmt_set, base_fmt));
            }
        }
        mgb_log_debug("opr:%s;format:%s;time:%f", val.opr->cname(),
                      opr_format_to_string(val.opr_fmt), val.time);
        solution.insert({val.opr, val.opr_fmt});
        pstate = val.prev;
        cur--;
    }
    return solution;
}
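
// Editorial note (not part of this patch): solve() implements the recurrence
//
//     T[k][s] = min over (s', fmt) of { T[k-1][s'] + opr_time(fmt)
//               + ivar_time(s' -> inputs of fmt) + ovar_time(fmt -> base) }
//
// where s ranges over the format bit-sets of the k-th cut's edges; the
// backward pass then follows the stored Value::prev pointers to emit one
// OprFormat per operator.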
/* =================== DynamicProgrammingSolver ====================== */
DynamicProgrammingSolver::Solution DynamicProgrammingSolver::do_solve(
        const Problem& problem) const {
    constexpr size_t MAX_STATES = 1024;
    Impl impl(MAX_STATES);
    return impl.solve(m_profiler.get(), problem);
}

bool DynamicProgrammingSolver::can_solve(const Problem& problem) const {
    auto&& available_tensor_formats = problem.available_tensor_formats();
    for (auto&& tensor_format : available_tensor_formats) {
        if (static_cast<uint32_t>(tensor_format) >= 32)
            return false;
    }
    return true;
}

// vim: syntax=cpp.doxygen
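
// ---------------------------------------------------------------------------
// Editorial sketch (not part of this patch): typical call sequence for this
// solver, assuming a LayoutTransformContext `ctx` and extracted graph
// partitions are already set up (see layout_transform_context.cpp and the
// tests below):
//
//     auto profiler = ProfilerBase::make_profiler();
//     DynamicProgrammingSolver solver{std::move(profiler)};
//     Problem problem(partitions[0], *ctx);
//     if (solver.can_solve(problem)) {
//         auto solution = solver.solve(problem);  // opr -> OprFormat
//     }
// ---------------------------------------------------------------------------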
@@ -0,0 +1,40 @@
/**
 * \file src/gopt/impl/layout_transform_context.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"

using namespace mgb;
using namespace gopt;

/* ================= LayoutTransformContext ================== */
LayoutTransformContext& LayoutTransformContext::add_opr_config(
        Typeinfo* opr, OprFormat opr_format) {
    auto& dispatchers = m_opr_configs[opr];
    dispatchers[opr_format] =
            OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
                    opr, opr_format);
    return *this;
}

LayoutTransformContext& LayoutTransformContext::add_opr_config(
        Typeinfo* opr, SmallVector<OprFormat> opr_formats) {
    auto& dispatchers = m_opr_configs[opr];
    for (auto opr_fmt : opr_formats) {
        dispatchers[opr_fmt] =
                OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
                        opr, opr_fmt);
    }
    return *this;
}

// vim: syntax=cpp.doxygen
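
// ---------------------------------------------------------------------------
// Editorial sketch (not part of this patch): add_opr_config returns *this, so
// configurations can be chained; this mirrors the usage in the tests below:
//
//     ctx->add_opr_config(opr::ConvBiasForward::typeinfo(),
//                         {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
//             .add_opr_config(opr::PoolingForward::typeinfo(),
//                             {OprFormat::NCHW4, OprFormat::NHWC});
// ---------------------------------------------------------------------------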
@@ -17,6 +17,7 @@
 #include "megbrain/graph/event.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/imgproc.h"
+#include "megbrain/opr/nn_int.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/plugin/base.h"
 #include "megbrain/serialization/sereg.h"
@@ -265,6 +266,10 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
     record.opr = opr;
     auto& costs = record.costs;
     for (auto&& i : available_configs) {
+        /// XXXX remove later
+        if (i.opr_format == OprFormat::NCHW &&
+            opr->input(0)->dtype().enumv() != DTypeEnum::Float32)
+            continue;
         costs[i.opr_format] = profile_operator(opr, base_config, i);
     }
     return record;
@@ -414,37 +419,42 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
             cb(Resize, 1),
 #undef cb
     };
+    static const ThinHashSet<Typeinfo*> skip_opr_types = {
+            TypeCvt::typeinfo(), Elemwise::typeinfo(),
+            ElemwiseMultiType::typeinfo()};
     ThinHashSet<VarNode*> vars;
     ThinHashSet<OperatorNodeBase*> oprs;
-    {
-        auto cb = [&cvprop, &vars, &oprs](OperatorNodeBase* opr) {
-            if (cvprop.is_const(opr))
-                return;
-            oprs.insert(opr);
-            auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
-            if (find == format_aware_input_tensors.end()) {
-                for (auto&& i : opr->input()) {
-                    if (!cvprop.is_const(i)) {
-                        vars.insert(i);
-                    }
-                }
-            } else {
-                size_t nr_input_tensor =
-                        std::min(find->second, opr->input().size());
-                for (size_t i = 0; i < nr_input_tensor; ++i) {
-                    if (!cvprop.is_const(opr->input(i))) {
-                        vars.insert(opr->input(i));
-                    }
-                }
-            }
-            vars.insert(opr->output(0));
-        };
-        DepOprIter iter{cb};
-        for (auto&& i : problem.graph_partition().input()) {
-            iter.set_visited(i->owner_opr());
-        }
-        for (auto&& o : problem.graph_partition().output()) {
-            iter.add(o->owner_opr());
+    ThinHashSet<OperatorNodeBase*> skip_oprs;
+    for (auto&& opr : problem.graph_partition().all_oprs()) {
+        if (cvprop.is_const(opr))
+            continue;
+        bool skip = true;
+        for (auto&& i : opr->input()) {
+            skip &= problem.graph_partition().input().count(i) > 0 ||
+                    skip_oprs.count(i->owner_opr()) > 0;
+        }
+        skip &= skip_opr_types.count(opr->dyn_typeinfo());
+        if (skip)
+            skip_oprs.insert(opr);
+        oprs.insert(opr);
+        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
+        if (find == format_aware_input_tensors.end()) {
+            for (auto&& i : opr->input()) {
+                if (!cvprop.is_const(i)) {
+                    vars.insert(i);
+                }
+            }
+        } else {
+            size_t nr_input_tensor =
+                    std::min(find->second, opr->input().size());
+            for (size_t i = 0; i < nr_input_tensor; ++i) {
+                if (!cvprop.is_const(opr->input(i))) {
+                    vars.insert(opr->input(i));
+                }
+            }
+        }
+        for (auto&& ov : opr->usable_output()) {
+            vars.insert(ov);
         }
     }
@@ -462,8 +472,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
         auto&& opr_configs = problem.opr_configs();
         auto find = opr_configs.find(opr->dyn_typeinfo());
         if (find == opr_configs.end()) {
-            opr_record[opr] = profile_operator(opr, base_format,
-                                               available_tensor_formats);
+            if (skip_oprs.count(opr) > 0) {
+                SmallVector<TensorFormats> tensor_formats = {base_format};
+                opr_record[opr] =
+                        profile_operator(opr, base_format, tensor_formats);
+            } else {
+                opr_record[opr] = profile_operator(opr, base_format,
+                                                   available_tensor_formats);
+            }
         } else {
             auto&& dispatchers = find->second;
             SmallVector<OprTensorFormatsConfiguration> configs;
@@ -0,0 +1,56 @@
/**
 * \file src/gopt/impl/profiling_based_solver.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"

using namespace mgb;
using namespace gopt;
using namespace opr;

/* =================== ProfilingBasedSolver ====================== */
ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler)
        : m_profiler{std::move(profiler)} {
    static const ThinHashSet<Typeinfo*> format_aware_oprs = {
#define cb(_Opr) _Opr::typeinfo()
            cb(Convolution),
            cb(ConvBiasForward),
            cb(ConvolutionBackwardData),
            cb(PoolingForward),
            cb(WarpPerspective),
            cb(Resize),
#undef cb
    };
    m_graph_partition_filter = [](const GraphPartition& partition) {
        bool has_format_aware_opr = false;
        for (auto&& opr : partition.all_oprs()) {
            if (!has_format_aware_opr &&
                format_aware_oprs.count(opr->dyn_typeinfo())) {
                has_format_aware_opr = true;
                break;
            }
        }
        return has_format_aware_opr;
    };
}

ProfilingBasedSolver::Solution ProfilingBasedSolver::solve(
        const Problem& problem) const {
    const auto& partition = problem.graph_partition();
    if (!m_graph_partition_filter(partition))
        return Solution{};
    return do_solve(problem);
}

// vim: syntax=cpp.doxygen
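
// ---------------------------------------------------------------------------
// Editorial sketch (not part of this patch): the filter overload lets callers
// replace the default "has a format-aware operator" policy; the predicate
// below (skip partitions with fewer than two operators) is purely
// illustrative:
//
//     DynamicProgrammingSolver solver{
//             ProfilerBase::make_profiler(),
//             [](const GraphPartition& p) { return p.all_oprs().size() >= 2; }};
// ---------------------------------------------------------------------------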
@@ -11,9 +11,9 @@
  */
 #include "megbrain/gopt/reformat_manager.h"
-#include "./utils.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/utils/arith_helper.h"
+#include "./utils.h"
 
 using namespace mgb;
 using namespace gopt;
@@ -87,21 +87,6 @@ bool ReformatManager::ReformatKey::Equal::operator()(
            lhs.attribute == rhs.attribute;
 }
 
-ReformatManager::ReformatKey&
-ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) {
-    static const ThinHashSet<std::pair<TensorFormats, TensorFormats>> set = {
-            {TensorFormats::NCHW, TensorFormats::NCHWc64},
-            {TensorFormats::NCHWc64, TensorFormats::NCHW},
-            {TensorFormats::NCHW, TensorFormats::NHWC},
-            {TensorFormats::NHWC, TensorFormats::NCHW}};
-    if (set.count({input_format, output_format}) > 0 &&
-        (dt.enumv() == DTypeEnum::QuantizedS4 ||
-         dt.enumv() == DTypeEnum::Quantized4Asymm)) {
-        input_dtype = output_dtype = dt.enumv();
-    }
-    return *this;
-}
-
 // =================== ReformatManager ====================*/
 ReformatManager::ReformatManager() {
     using Attribute = ReformatKey::Attribute;
@@ -378,7 +363,7 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue(
             divup(orig_channel, input_alignment) * input_alignment;
     size_t aligned_out_channel =
             divup(orig_channel, output_alignment) * output_alignment;
-    size_t common_alignment = input_alignment * output_alignment /
+    size_t common_alignment = input_alignment * output_alignment /
                               gcd(input_alignment, output_alignment);
     size_t aligned_channel =
             divup(orig_channel, common_alignment) * common_alignment;
@@ -427,11 +412,11 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
     for (size_t i = 0; i < input_shape.ndim; ++i) {
         if (input_shape[i].name() == Dimension::Name::C &&
             input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
-            in_channels = orig_var->shape()[i];
+            in_channels = orig_var->shape()[i] * input_shape[i].stride();
             input_channel_idx = i;
-            mgb_assert(input_shape[i].stride() == 1,
-                       "unsupport weight format(got:%s)",
-                       input_shape.to_string().c_str());
+            // mgb_assert(input_shape[i].stride() == 1,
+            //            "unsupport weight format(got:%s)",
+            //            input_shape.to_string().c_str());
         } else if ((input_shape[i].name() == Dimension::Name::K ||
                     input_shape[i].name() == Dimension::Name::N) &&
                    input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
@@ -536,7 +521,8 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var,
                "formats(var:%s;shp:%s;fmt:%s)",
                var->cname(), oshp.to_string().c_str(),
                orig_shape.to_string().c_str());
-    if (oshp.is_scalar()) return oshp;
+    if (oshp.is_scalar())
+        return oshp;
     TensorShape tshp;
     ThinHashMap<Dimension::Name, int> name2dominant;
     for (size_t i = 0; i < orig_shape.ndim; ++i) {
@@ -597,4 +583,32 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var,
     return tshp;
 }
 
+ReformatManager::AlignmentDesc mgb::gopt::make_aligned_desc(
+        TensorFormats weight_format, TensorFormats out_feature_format) {
+    using AlignmentDesc = ReformatManager::AlignmentDesc;
+    using Name = Dimension::Name;
+    auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format);
+    auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format);
+    size_t out_channel_alignment = 1;
+    for (size_t i = 0; i < out_shape.ndim; ++i) {
+        auto name = out_shape[i].name();
+        auto extent = out_shape[i].extent();
+        if ((name == Name::C || name == Name::K) &&
+            extent == Dimension::UNDETERMINED_EXTENT) {
+            out_channel_alignment = out_shape[i].stride();
+            break;
+        }
+    }
+    Name out_channel_name;
+    for (size_t i = 0; i < weight_shape.ndim; ++i) {
+        auto name = weight_shape[i].name();
+        auto extent = weight_shape[i].extent();
+        if ((name == Name::N || name == Name::K) &&
+            extent == Dimension::UNDETERMINED_EXTENT) {
+            out_channel_name = name;
+        }
+    }
+    return AlignmentDesc{out_channel_name, out_channel_alignment};
+}
+
 // vim: syntax=cpp.doxygen
@@ -304,10 +304,15 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
             }
         }
         partition->opr_set().insert(opr);
+        partition->all_oprs().push_back(opr);
         for (const auto& i : opr->input())
            partition->input().insert(i);
     }
 }
+for (auto&& partition : partitions) {
+    auto& all_oprs = partition.all_oprs();
+    std::reverse(all_oprs.begin(), all_oprs.end());
+}
 return partitions;
 }
@@ -36,6 +36,28 @@ static inline const char* opr_format_to_string(
 #undef cb
 }
 
+static inline TensorFormats opr_format_to_tensor_formats(
+        OprTensorFormatsConfiguration::OprFormat opr_format) {
+    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
+    switch (opr_format) {
+        case OprFormat::NCHW:
+            return TensorFormats::NCHW;
+        case OprFormat::NHWC:
+            return TensorFormats::NHWC;
+        case OprFormat::NCHW4:
+            return TensorFormats::NCHWc4;
+        case OprFormat::NCHW32:
+            return TensorFormats::NCHWc32;
+        case OprFormat::NCHW64:
+            return TensorFormats::NCHWc64;
+        case OprFormat::CHWN4:
+            return TensorFormats::CHWNc4;
+        default:
+            mgb_throw(AssertionError, "format(%s) is not supported",
+                      opr_format_to_string(opr_format));
+    }
+}
+
 static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
         TensorFormats format) {
     switch (format) {
@@ -11,6 +11,7 @@
  */
 #pragma once
+#include "megbrain/gopt/framework.h"
 #include "megbrain/gopt/reformat_manager.h"
 #include "megbrain/gopt/subgraph_extractor.h"
 #include "megbrain/opr/dnn/convolution.h"
@@ -41,14 +42,16 @@ struct OprTensorFormatsConfiguration {
 /*!
  * \brief A structure that describes the global layout transform problem
  */
-class Problem {
+class LayoutTransformContext {
 public:
+    using OprList = SubGraphExtractor::OprList;
     using OprFormat = OprTensorFormatsConfiguration::OprFormat;
     using OprTensorFormatsDispatcher =
             OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
     using OprConfigTrait =
             ThinHashMap<Typeinfo*,
                         ThinHashMap<OprFormat, OprTensorFormatsDispatcher*>>;
+    using ReformatAttribute = ReformatManager::ReformatKey::Attribute;
     struct Attribute {
         OprFormat base_opr_format;  /// the base opr format indicates that the
                                     /// network to be optimized is constructed
@@ -62,58 +65,110 @@ public:
                                     /// (like elemwise, elemwise multi type,
                                     /// typecvt etc.) are built in the base
                                     /// tensor format.
+        ReformatAttribute
+                reformat_attribute;  /// additional reformat attribute, which
+                                     /// indicates whether to pad nhwc layout
+                                     /// automatically or to enable nhwcd4
+                                     /// format on opencl platform to use
+                                     /// image object
     };
-    Problem(const GraphPartition& graph_partition,
-            const SmallVector<TensorFormats>& available_tensor_formats,
-            const OprConfigTrait& opr_config, const Attribute& attribute)
-            : m_graph_partition{graph_partition},
-              m_available_tensor_formats{available_tensor_formats},
-              m_opr_configs{opr_config},
+    LayoutTransformContext() = delete;
+    LayoutTransformContext(OprList opr_list,
+                           SmallVector<TensorFormats> available_tensor_formats,
+                           Attribute attribute)
+            : m_opr_list{std::move(opr_list)},
+              m_available_tensor_formats{std::move(available_tensor_formats)},
+              m_attribute{attribute} {}
+    LayoutTransformContext(OprList opr_list,
+                           SmallVector<TensorFormats> available_tensor_formats,
+                           OprConfigTrait opr_configs, Attribute attribute)
+            : m_opr_list{std::move(opr_list)},
+              m_available_tensor_formats{std::move(available_tensor_formats)},
+              m_opr_configs{std::move(opr_configs)},
               m_attribute{attribute} {}
+    const OprList& opr_list() const { return m_opr_list; }
+    const SmallVector<TensorFormats>& available_tensor_formats() const {
+        return m_available_tensor_formats;
+    }
+    const OprConfigTrait& opr_configs() const { return m_opr_configs; }
+    Attribute attribute() const { return m_attribute; }
+    /*!
+     * \brief add an op format configuration for a particular operator type
+     * \param opr runtime typeinfo of operator
+     * \param opr_format op format configuration to be enabled in the layout
+     * transform problem
+     */
+    LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormat opr_format);
+    /*!
+     * \brief add a vector of op format configurations for a particular
+     * operator type
+     * \param opr runtime typeinfo of operator
+     * \param opr_formats op format configurations to be enabled in the layout
+     * transform problem
+     */
+    LayoutTransformContext& add_opr_config(Typeinfo* opr,
+                                           SmallVector<OprFormat> opr_formats);
+
+private:
+    OprList m_opr_list;  /// supported operator list
+    SmallVector<TensorFormats>
+            m_available_tensor_formats;  /// the available tensor formats,
+                                         /// used for format agnostic operators
+                                         /// (like elemwise, elemwise multi
+                                         /// type, typecvt, etc.)
+    OprConfigTrait m_opr_configs;  /// the available opr format configurations,
+                                   /// used for format aware operators (like
+                                   /// conv, deconv, conv_bias, etc.)
+    Attribute m_attribute;  /// the extra attributes to describe the problem
+};
+
+class Problem {
+public:
+    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
+    using OprTensorFormatsDispatcher =
+            OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
+    using OprConfigTrait = LayoutTransformContext::OprConfigTrait;
+    using Attribute = LayoutTransformContext::Attribute;
+    Problem(const GraphPartition& graph_partition,
+            const LayoutTransformContext& ctx)
+            : m_graph_partition{graph_partition}, m_ctx{ctx} {}
     ~Problem() noexcept = default;
 
     const GraphPartition& graph_partition() const { return m_graph_partition; }
-    const OprConfigTrait& opr_configs() const { return m_opr_configs; }
+    const OprConfigTrait& opr_configs() const { return m_ctx.opr_configs(); }
     const SmallVector<TensorFormats>& available_tensor_formats() const {
-        return m_available_tensor_formats;
+        return m_ctx.available_tensor_formats();
     }
     TensorFormats base_format() const {
-        return m_attribute.base_tensor_formats;
+        return m_ctx.attribute().base_tensor_formats;
    }
+    /*!
+     * \brief return the tensor formats configuration of an operator in the
+     * default op format
+     */
     OprTensorFormatsConfiguration base_config(
             const cg::OperatorNodeBase* opr) const {
         auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
-                opr->dyn_typeinfo(), m_attribute.base_opr_format);
+                opr->dyn_typeinfo(), m_ctx.attribute().base_opr_format);
         auto rst = (*_)(opr);
         if (rst.valid())
             return rst.val();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
-        config.opr_format = m_attribute.base_opr_format;
+        config.opr_format = m_ctx.attribute().base_opr_format;
         for (const auto& i : opr->input()) {
             config.input_dtypes.emplace_back(i->dtype().enumv());
-            config.input_tensor_formats.emplace_back(
-                    m_attribute.base_tensor_formats);
+            config.input_tensor_formats.emplace_back(base_format());
             config.input_tensor_types.emplace_back(TensorType::FEATURE);
         }
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
-        config.output_tensor_formats.emplace_back(
-                m_attribute.base_tensor_formats);
+        config.output_tensor_formats.emplace_back(base_format());
         return config;
     }
 
 private:
     const GraphPartition& m_graph_partition;  /// the graph partition
-    const SmallVector<TensorFormats>&
-            m_available_tensor_formats;  /// the available tensor formats,
-                                         /// used for format agnostic operators
-                                         /// (like elemwise, elemwise multi
-                                         /// type, typecvt, etc.)
-    const OprConfigTrait&
-            m_opr_configs;  /// the available opr format configurations, used
-                            /// for format aware operators (like conv, deconv,
-                            /// conv_bias, etc.)
-    Attribute m_attribute;  /// the extra attributes to describe the problem
+    const LayoutTransformContext& m_ctx;
 };
 /*!
@@ -170,6 +225,92 @@
 
     static std::unique_ptr<ProfilerBase> make_profiler();
 };
+
+/*!
+ * \brief abstract solver
+ */
+class SolverBase {
+public:
+    using OprFormat = Problem::OprFormat;
+    using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>;
+    SolverBase() = default;
+    virtual ~SolverBase() = default;
+    /*!
+     * \brief solve the given problem
+     */
+    virtual Solution solve(const Problem& problem) const = 0;
+    /*!
+     * \brief check whether the given problem can be solved by this
+     * algorithm (i.e. solver)
+     */
+    virtual bool can_solve(const Problem& problem) const = 0;
+};
+
+/*!
+ * \brief solvers that first collect the costs of operators in different op
+ * formats and the costs of the layout transforms of varnodes, using a
+ * user-provided profiler on the target device; this profiling step can be
+ * time-consuming.
+ */
+class ProfilingBasedSolver : public SolverBase {
+public:
+    using GraphPartitionFilter =
+            thin_function<bool(const GraphPartition& graph_partition)>;
+    ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler);
+    /*!
+     * \note some graph partitions (for example, those without format aware
+     * operators like conv, deconv, warp, resize etc.) will be filtered out by
+     * the GraphPartitionFilter, which can reduce the profiling time. */
+    ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler,
+                         GraphPartitionFilter graph_partition_filter)
+            : m_profiler{std::move(profiler)},
+              m_graph_partition_filter{std::move(graph_partition_filter)} {}
+    virtual ~ProfilingBasedSolver() = default;
+    Solution solve(const Problem& problem) const override;
+    virtual Solution do_solve(const Problem& problem) const = 0;
+
+protected:
+    std::unique_ptr<ProfilerBase> m_profiler;
+
+private:
+    GraphPartitionFilter m_graph_partition_filter;
+};
+
+/*!
+ * \brief A solver that solves the layout selection problem with a dynamic
+ * programming algorithm (Markov decision process).
+ */
+class DynamicProgrammingSolver final : public ProfilingBasedSolver {
+public:
+    DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler)
+            : ProfilingBasedSolver(std::move(profiler)) {}
+    DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler,
+                             GraphPartitionFilter graph_partition_filter)
+            : ProfilingBasedSolver(std::move(profiler),
+                                   std::move(graph_partition_filter)) {}
+    ~DynamicProgrammingSolver() noexcept = default;
+    Solution do_solve(const Problem& problem) const override;
+    bool can_solve(const Problem& problem) const override;
+
+private:
+    class Impl;
+};
+
+/*!
+ * \brief A layout transform pass, which converts the format of each operator
+ * to the optimal format using the results of the solver.
+ */
+class LayoutTransformPass final : public Pass {
+public:
+    const char* name() const override { return "layout assignment pass"; }
+    void apply(OptState& opt) const override;
+    LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx,
+                        std::unique_ptr<SolverBase> solver)
+            : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {}
+
+private:
+    std::unique_ptr<LayoutTransformContext> m_ctx;
+    std::unique_ptr<SolverBase> m_solver;
+};
 }  // namespace gopt
 }  // namespace mgb
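
// ---------------------------------------------------------------------------
// Editorial sketch (not part of this patch): wiring the pieces declared above
// together; `make_ctx` stands for any function that builds a
// LayoutTransformContext (the tests below define one):
//
//     auto solver = std::make_unique<DynamicProgrammingSolver>(
//             ProfilerBase::make_profiler());
//     LayoutTransformPass pass{make_ctx(), std::move(solver)};
// ---------------------------------------------------------------------------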
@@ -84,7 +84,7 @@ public:
                   output_dtype{DTypeEnum::Float32},
                   attribute{Attribute::DEFAULT} {}
         ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
-                    Attribute attribute_ = Attribute::DEFAULT,
+                    Attribute attribute_,
                     DTypeEnum input_dtype_ = DTypeEnum::Float32,
                     DTypeEnum output_dtype_ = DTypeEnum::Float32)
                 : input_format{input_format_},
@@ -92,6 +92,15 @@ public:
                   input_dtype{input_dtype_},
                   output_dtype{output_dtype_},
                   attribute{attribute_} {}
+        ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
+                    DTypeEnum input_dtype_ = DTypeEnum::Float32,
+                    DTypeEnum output_dtype_ = DTypeEnum::Float32,
+                    Attribute attribute_ = Attribute::DEFAULT)
+                : input_format{input_format_},
+                  output_format{output_format_},
+                  input_dtype{input_dtype_},
+                  output_dtype{output_dtype_},
+                  attribute{attribute_} {}
         struct Hash {
             size_t operator()(const ReformatKey& key) const;
         };
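
// Editorial note (not part of this patch): with the Attribute default removed
// from the first constructor, a call such as
// ReformatKey{fmt_in, fmt_out, dtype_in, dtype_out} now unambiguously selects
// the new dtype-first overload; passing an Attribute as the third argument
// still selects the original one.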
@@ -99,7 +108,6 @@ public:
         struct Equal {
             bool operator()(const ReformatKey& lhs,
                             const ReformatKey& rhs) const;
         };
-        ReformatKey& deduce_reformat_dtype_enum(const DType& dt);
     };
     using ReformatCache =
             std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash,
@@ -130,6 +138,9 @@ TensorShape make_aligned_weight_shape(const VarNode* var,
                                       TensorFormats orig_formats,
                                       TensorFormats target_formats,
                                       TensorFormats extra_formats);
+
+ReformatManager::AlignmentDesc make_aligned_desc(
+        TensorFormats weight_format, TensorFormats out_feature_format);
 }  // namespace gopt
 }  // namespace mgb
@@ -20,6 +20,7 @@ class GraphPartition {
 public:
     using VarNodeSet = ThinHashSet<VarNode*>;
     using OperatorNodeSet = ThinHashSet<cg::OperatorNodeBase*>;
+    using OperatorNodeList = std::vector<cg::OperatorNodeBase*>;
 
     class InputPlaceholder;
@@ -32,15 +33,18 @@ public:
     const OperatorNodeSet& opr_set() const { return m_opr_set; }
     const VarNodeSet& input() const { return m_inputs; }
     const VarNodeSet& output() const { return m_outputs; }
+    const OperatorNodeList& all_oprs() const { return m_oprs; }
     OperatorNodeSet& opr_set() { return m_opr_set; }
+    OperatorNodeList& all_oprs() { return m_oprs; }
     VarNodeSet& input() { return m_inputs; }
     VarNodeSet& output() { return m_outputs; }
 
 private:
-    std::pair<VarNodeArray, VarNodeArray> replace_graph_by_placeholder() const;
     OperatorNodeSet m_opr_set;
+    OperatorNodeList m_oprs;
     VarNodeSet m_inputs;
     VarNodeSet m_outputs;
+    std::pair<VarNodeArray, VarNodeArray> replace_graph_by_placeholder() const;
 };
 
 class SubGraphExtractor {
@@ -10,6 +10,7 @@
  * implied.
  */
+#include "megbrain/plugin/profiler.h"
 #include "./helper.h"
 #include "megbrain/gopt/global_layout_transform.h"
 #include "megbrain/gopt/inference.h"
@@ -22,123 +23,59 @@ using namespace mgb; | |||||
using namespace gopt; | using namespace gopt; | ||||
using namespace serialization; | using namespace serialization; | ||||
#if MGB_CUDA | |||||
namespace { | namespace { | ||||
class LayoutTransformContext : public NonCopyableObj { | |||||
public: | |||||
using OprList = SubGraphExtractor::OprList; | |||||
using OprFormat = Problem::OprFormat; | |||||
using OprConfigTrait = Problem::OprConfigTrait; | |||||
LayoutTransformContext() = delete; | |||||
LayoutTransformContext(OprList opr_list, | |||||
SmallVector<TensorFormats> available_tensor_formats, | |||||
OprConfigTrait opr_configs) | |||||
: m_opr_list{std::move(opr_list)}, | |||||
m_available_tensor_formats{std::move(available_tensor_formats)}, | |||||
m_opr_configs{std::move(opr_configs)} {} | |||||
const OprList& opr_list() const { return m_opr_list; } | |||||
const SmallVector<TensorFormats>& available_tensor_formats() const { | |||||
return m_available_tensor_formats; | |||||
} | |||||
const OprConfigTrait& opr_configs() const { return m_opr_configs; } | |||||
static std::unique_ptr<LayoutTransformContext> make() { | |||||
OprList opr_list = { | |||||
opr::ConvBiasForward::typeinfo(), | |||||
opr::ConvolutionForward::typeinfo(), | |||||
opr::ConvolutionBackwardData::typeinfo(), | |||||
opr::ElemwiseMultiType::typeinfo(), | |||||
opr::Elemwise::typeinfo(), | |||||
opr::TypeCvt::typeinfo(), | |||||
opr::PoolingForward::typeinfo(), | |||||
opr::WarpPerspectiveForward::typeinfo(), | |||||
}; | |||||
OprConfigTrait opr_configs; | |||||
{ | |||||
auto& dispatchers = opr_configs[opr::ConvBias::typeinfo()]; | |||||
#define cb(_fmt) \ | |||||
dispatchers[OprFormat::_fmt] = \ | |||||
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ | |||||
opr::ConvBias::typeinfo(), OprFormat::_fmt); | |||||
cb(NCHW4); | |||||
cb(NCHW32); | |||||
cb(NHWC); | |||||
cb(NCHW64); | |||||
cb(CHWN4); | |||||
#undef cb | |||||
} | |||||
{ | |||||
auto& dispatchers = | |||||
opr_configs[opr::ConvolutionBackwardData::typeinfo()]; | |||||
#define cb(_fmt) \ | |||||
dispatchers[OprFormat::_fmt] = \ | |||||
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ | |||||
opr::ConvolutionBackwardData::typeinfo(), \ | |||||
OprFormat::_fmt); | |||||
cb(NCHW4); | |||||
#undef cb | |||||
} | |||||
{ | |||||
auto& dispatchers = | |||||
opr_configs[opr::ConvolutionForward::typeinfo()]; | |||||
#define cb(_fmt) \ | |||||
dispatchers[OprFormat::_fmt] = \ | |||||
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ | |||||
opr::ConvolutionForward::typeinfo(), OprFormat::_fmt); | |||||
cb(NCHW4); | |||||
#undef cb | |||||
} | |||||
{ | |||||
auto& dispatchers = opr_configs[opr::PoolingForward::typeinfo()]; | |||||
#define cb(_fmt) \ | |||||
dispatchers[OprFormat::_fmt] = \ | |||||
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ | |||||
opr::PoolingForward::typeinfo(), OprFormat::_fmt); | |||||
cb(NCHW4); | |||||
cb(NCHW32); | |||||
cb(NHWC); | |||||
cb(NCHW64); | |||||
cb(CHWN4); | |||||
#undef cb | |||||
} | |||||
{ | |||||
auto& dispatchers = | |||||
opr_configs[opr::WarpPerspectiveForward::typeinfo()]; | |||||
#define cb(_fmt) \ | |||||
dispatchers[OprFormat::_fmt] = \ | |||||
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ | |||||
opr::WarpPerspectiveForward::typeinfo(), OprFormat::_fmt); | |||||
cb(NHWC); | |||||
cb(NCHW4); | |||||
cb(NCHW64); | |||||
#undef cb | |||||
} | |||||
SmallVector<TensorFormats> available_tensor_formats = { | |||||
TensorFormats::NHWC, TensorFormats::NCHWc4, | |||||
TensorFormats::NCHWc32, TensorFormats::NCHWc64}; | |||||
return std::make_unique<LayoutTransformContext>( | |||||
std::move(opr_list), std::move(available_tensor_formats), | |||||
std::move(opr_configs)); | |||||
} | |||||
std::unique_ptr<LayoutTransformContext> make_ctx() { | |||||
using OprFormat = LayoutTransformContext::OprFormat; | |||||
using OprList = LayoutTransformContext::OprList; | |||||
using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | |||||
using Attribute = LayoutTransformContext::Attribute; | |||||
OprList opr_list = { | |||||
opr::ConvBiasForward::typeinfo(), | |||||
opr::ConvolutionForward::typeinfo(), | |||||
opr::ConvolutionBackwardData::typeinfo(), | |||||
opr::ElemwiseMultiType::typeinfo(), | |||||
opr::Elemwise::typeinfo(), | |||||
opr::TypeCvt::typeinfo(), | |||||
opr::PoolingForward::typeinfo(), | |||||
opr::WarpPerspectiveForward::typeinfo(), | |||||
}; | |||||
private: | |||||
OprList m_opr_list; | |||||
SmallVector<TensorFormats> m_available_tensor_formats; | |||||
OprConfigTrait m_opr_configs; | |||||
}; | |||||
}; // namespace | |||||
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHW,    TensorFormats::NHWC,
+            TensorFormats::NCHWc4,  TensorFormats::NCHWc32,
+            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
+    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
+                           ReformatAttribute::DEFAULT};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats),
+            attribute);
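    // Naming convention of the formats above: NCHWcN is NCHW with the
    // channel dimension packed in blocks of N (NCHWc4, NCHWc32, NCHWc64),
    // and CHWNc4 is the CHWN variant with 4-channel packing.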
+    ctx->add_opr_config(
+               opr::ConvBiasForward::typeinfo(),
+               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4,
+                OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
+            .add_opr_config(opr::ConvolutionForward::typeinfo(),
+                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
+                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::PoolingForward::typeinfo(),
+                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
+                     OprFormat::NCHW64, OprFormat::CHWN4})
+            .add_opr_config(
+                    opr::WarpPerspectiveForward::typeinfo(),
+                    {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
+    return ctx;
+}
+}  // namespace
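// Sketch of how the tests below consume this context (all names as they
// appear in this diff; `out` is an illustrative stand-in for the output var
// of the graph under test): build the context once, extract the partition
// containing the target var, wrap it into a Problem and profile it:
//
//     auto ctx = make_ctx();
//     SubGraphExtractor extractor(ctx->opr_list());
//     auto partitions = extractor.extract({out});
//     Problem problem(partitions[0], *ctx);
//     auto profiler = ProfilerBase::make_profiler();
//     auto rst = profiler->profile(problem);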
#if MGB_CUDA
#if CUDA_VERSION >= 10020
TEST(TestProfiler, Conv) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
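    // The conv/deconv/warp/pooling tests pin the device to compute
    // capability 7.5 (Turing), presumably because the int8/int4
    // NCHW32/NCHW64 paths being profiled are only available there.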
-    auto ctx = LayoutTransformContext::make();
+    auto ctx = make_ctx();
    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -177,14 +114,10 @@ TEST(TestProfiler, Conv) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({c2}, strategy);
-    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({c2});
    ASSERT_EQ(partitions.size(), 1u);
-    using Attribute = Problem::Attribute;
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
-    Problem problem(partitions[0], ctx->available_tensor_formats(),
-                    ctx->opr_configs(), attribute);
+    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
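    // The Problem constructor now takes the whole LayoutTransformContext,
    // which already carries the available tensor formats, the operator
    // configs and the base-format attribute, so the per-test Attribute
    // boilerplate disappears. The same replacement recurs in every test
    // below.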
@@ -204,7 +137,7 @@ TEST(TestProfiler, Deconv) {
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
-    auto ctx = LayoutTransformContext::make();
+    auto ctx = make_ctx();
    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -238,14 +171,10 @@ TEST(TestProfiler, Deconv) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({c2}, strategy);
-    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({c2});
    ASSERT_EQ(partitions.size(), 1u);
-    using Attribute = Problem::Attribute;
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
-    Problem problem(partitions[0], ctx->available_tensor_formats(),
-                    ctx->opr_configs(), attribute);
+    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -262,7 +191,7 @@ TEST(TestProfiler, Warp) {
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
-    auto ctx = LayoutTransformContext::make();
+    auto ctx = make_ctx();
    constexpr size_t INP_H = 10, INP_W = 10, N = 16;
@@ -307,14 +236,9 @@ TEST(TestProfiler, Warp) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({w1}, strategy);
-    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({w1});
-    ASSERT_EQ(partitions.size(), 1u);
-    using Attribute = Problem::Attribute;
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
-    Problem problem(partitions[0], ctx->available_tensor_formats(),
-                    ctx->opr_configs(), attribute);
+    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -330,7 +254,7 @@ TEST(TestProfiler, Pooling) {
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
-    auto ctx = LayoutTransformContext::make();
+    auto ctx = make_ctx();
    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -353,14 +277,10 @@ TEST(TestProfiler, Pooling) {
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({p2}, strategy);
-    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({p2});
    ASSERT_EQ(partitions.size(), 1u);
-    using Attribute = Problem::Attribute;
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
-    Problem problem(partitions[0], ctx->available_tensor_formats(),
-                    ctx->opr_configs(), attribute);
+    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
@@ -373,8 +293,7 @@ TEST(TestProfiler, Elemwise) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
-    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
-    auto ctx = LayoutTransformContext::make();
+    auto ctx = make_ctx();
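    // Unlike the other profiler tests, the compute-capability pin is
    // dropped here; plain elemwise kernels presumably run on any device.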
    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
@@ -403,14 +322,10 @@ TEST(TestProfiler, Elemwise) {
            OperatorNodeConfig(
                    dtype::Quantized4Asymm(13.f, static_cast<uint8_t>(4))));
-    using OprFormat = OprTensorFormatsConfiguration::OprFormat;
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({q4e});
    ASSERT_EQ(partitions.size(), 1u);
-    using Attribute = Problem::Attribute;
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW};
-    Problem problem(partitions[0], ctx->available_tensor_formats(),
-                    ctx->opr_configs(), attribute);
+    Problem problem(partitions[0], *ctx);
    auto profiler = ProfilerBase::make_profiler();
    auto rst = profiler->profile(problem);
    const auto& opr_rst = rst.opr_record;
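    // opr_record maps each profiled operator to its measured runtimes per
    // OprFormat; the var_rst checked below (presumably rst.var_record)
    // holds the profiled layout-conversion costs keyed by VarNode.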
@@ -423,7 +338,6 @@ TEST(TestProfiler, Elemwise) {
    EXPECT_TRUE(var_rst.count(q8a.node()) > 0);
    EXPECT_TRUE(var_rst.count(q8b.node()) > 0);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -447,6 +447,7 @@ TEST(TestReformatManager, AutoAlignedFeatureProfiling) {
    for (size_t i = 0; i < RUNS; ++i)
        func->execute();
    double time_profiler = profiler->duration() * 1e6;
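    // duration() is assumed to be reported in seconds; * 1e6 converts it to
    // microseconds so it can be compared against the CUDA-event measurement
    // time_cuda_evt printed alongside it.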
+    printf("time: %f, %f\n", time_cuda_evt, time_profiler);
    MGB_CUDA_CHECK(cudaEventDestroy(evt0));
    MGB_CUDA_CHECK(cudaEventDestroy(evt1));
}