@@ -260,3 +260,15 @@ def replace_oprs(dst, oprmap): | |||||
repl_dst_vec.push_back(j) | repl_dst_vec.push_back(j) | ||||
return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | ||||
def set_priority_to_id(dest_vars): | |||||
"""For all oprs in the subgraph constructed by dest_vars | |||||
set its priority to id if its original priority is zero | |||||
:param dest_vars: target vars representing the graph | |||||
""" | |||||
dest_vec = _mgb._VectorSymbolVar() | |||||
for i in dest_vars: | |||||
assert isinstance(i, _mgb.SymbolVar) | |||||
dest_vec.push_back(i) | |||||
_mgb._set_priority_to_id(dest_vec) |
@@ -84,6 +84,8 @@ class trace: | |||||
:param log_level: Log level. | :param log_level: Log level. | ||||
:param sublinear_memory_config: Configuration for sublinear memory optimization. | :param sublinear_memory_config: Configuration for sublinear memory optimization. | ||||
If not None, it enables sublinear memory optimization with given setting. | If not None, it enables sublinear memory optimization with given setting. | ||||
:param allreduce_pack_max_size: Maximum size of an allreduce pack in MB. | |||||
If not None, multiple gradients will be packed and synchronized together | |||||
:param profiling: Whether to profile compiled trace. Default: False | :param profiling: Whether to profile compiled trace. Default: False | ||||
""" | """ | ||||
@@ -107,6 +109,7 @@ class trace: | |||||
opt_level: int = None, | opt_level: int = None, | ||||
log_level: int = None, | log_level: int = None, | ||||
sublinear_memory_config: SublinearMemoryConfig = None, | sublinear_memory_config: SublinearMemoryConfig = None, | ||||
allreduce_pack_max_size: int = None, | |||||
profiling: bool = False | profiling: bool = False | ||||
): | ): | ||||
self.__wrapped__ = func | self.__wrapped__ = func | ||||
@@ -114,6 +117,7 @@ class trace: | |||||
self._graph_opt_level = opt_level | self._graph_opt_level = opt_level | ||||
self._log_level = log_level | self._log_level = log_level | ||||
self._sublinear_memory_config = sublinear_memory_config | self._sublinear_memory_config = sublinear_memory_config | ||||
self._allreduce_pack_max_size = allreduce_pack_max_size | |||||
self._status = self._UNSTARTED | self._status = self._UNSTARTED | ||||
self._args = None | self._args = None | ||||
self._kwargs = None | self._kwargs = None | ||||
@@ -313,6 +317,9 @@ class trace: | |||||
"sublinear_mem_cofig.num_worker", | "sublinear_mem_cofig.num_worker", | ||||
self._sublinear_memory_config.num_worker, | self._sublinear_memory_config.num_worker, | ||||
) | ) | ||||
# pack allreduce | |||||
if self._allreduce_pack_max_size is not None: | |||||
cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size) | |||||
# profile | # profile | ||||
if self._profiling: | if self._profiling: | ||||
self._profiler = CompGraphProfiler(cg) | self._profiler = CompGraphProfiler(cg) | ||||
@@ -391,6 +398,7 @@ class trace: | |||||
outputs = [outputs] | outputs = [outputs] | ||||
# _run_wrapped has checked validity of outputs | # _run_wrapped has checked validity of outputs | ||||
self._sym_outputs = tuple(i._symvar for i in outputs) | self._sym_outputs = tuple(i._symvar for i in outputs) | ||||
mgb.comp_graph_tools.set_priority_to_id(self._outspec) | |||||
self._compiled_func = graph.get_default_graph().compile(None, self._outspec) | self._compiled_func = graph.get_default_graph().compile(None, self._outspec) | ||||
def trace(self, *args: Tensor, **kwargs): | def trace(self, *args: Tensor, **kwargs): | ||||
@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): | |||||
:param loss: The obtained loss tensor | :param loss: The obtained loss tensor | ||||
""" | """ | ||||
rst = [] | rst = [] | ||||
priority = 0 | |||||
params = [] | params = [] | ||||
for group in self.param_groups: | for group in self.param_groups: | ||||
for param in group["params"]: | for param in group["params"]: | ||||
@@ -180,14 +179,14 @@ class Optimizer(metaclass=ABCMeta): | |||||
for param, grad in zip(params, grads): | for param, grad in zip(params, grads): | ||||
if is_distributed(): | if is_distributed(): | ||||
priority += 1 | |||||
with opr_priority_scope(cg, -priority): | |||||
# all_reduce_mean | |||||
with opr_priority_scope(cg, -(2 ** 30)): | |||||
# always run all_reduce_mean first except add_update | |||||
grad = ( | grad = ( | ||||
all_reduce_sum(grad, "grad_" + str(get_group_id())) | all_reduce_sum(grad, "grad_" + str(get_group_id())) | ||||
/ get_world_size() | / get_world_size() | ||||
) | ) | ||||
with opr_priority_scope(cg, (1 << 30) - priority): | |||||
with opr_priority_scope(cg, -(2 ** 31)): | |||||
# always run add_update first | |||||
grad_update = add_update(param.grad, grad) | grad_update = add_update(param.grad, grad) | ||||
else: | else: | ||||
grad_update = add_update(param.grad, grad) | grad_update = add_update(param.grad, grad) | ||||
@@ -66,6 +66,8 @@ bool _config::set_comp_graph_option( | |||||
SET_CG_OPTION(graph_opt.jit); | SET_CG_OPTION(graph_opt.jit); | ||||
SET_CG_OPTION(graph_opt.tensorrt); | SET_CG_OPTION(graph_opt.tensorrt); | ||||
SET_CG_OPTION(graph_opt_level); | SET_CG_OPTION(graph_opt_level); | ||||
SET_CG_OPTION(allreduce_pack_max_size); | |||||
SET_CG_OPTION(allreduce_pack_ignore_first); | |||||
SET_CG_OPTION(var_sanity_check_first_run); | SET_CG_OPTION(var_sanity_check_first_run); | ||||
SET_CG_OPTION(no_profiling_on_shape_change); | SET_CG_OPTION(no_profiling_on_shape_change); | ||||
SET_CG_OPTION(allocate_static_mem_after_graph_compile); | SET_CG_OPTION(allocate_static_mem_after_graph_compile); | ||||
@@ -1,3 +1,7 @@ | |||||
%{ | |||||
#include "megbrain/gopt/framework.h" | |||||
%} | |||||
%inline { | %inline { | ||||
SymbolVarArray _get_owner_opr_inputs(SymbolVar var) { | SymbolVarArray _get_owner_opr_inputs(SymbolVar var) { | ||||
@@ -35,5 +39,17 @@ | |||||
} | } | ||||
return mgb::cg::replace_oprs(vars, oprmap); | return mgb::cg::replace_oprs(vars, oprmap); | ||||
} | } | ||||
void _set_priority_to_id(const SymbolVarArray& dest_vars) { | |||||
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { | |||||
if (opr->node_prop().attribute().priority == 0) { | |||||
opr->node_prop().attribute().priority = opr->id(); | |||||
} | |||||
}; | |||||
mgb::cg::DepOprIter dep_iter{on_opr}; | |||||
for (const SymbolVar& var : dest_vars) { | |||||
dep_iter.add(var); | |||||
} | |||||
} | |||||
} | } | ||||
// vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}} | // vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}} |
@@ -441,12 +441,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
optimizer.verbosity(options().log_level); | optimizer.verbosity(options().log_level); | ||||
optimizer.enable_check_result(options().graph_opt_level < 0); | optimizer.enable_check_result(options().graph_opt_level < 0); | ||||
if (sopr_stat.has_virtual_grad) { | if (sopr_stat.has_virtual_grad) { | ||||
if (need_opt) | |||||
if (need_opt) { | |||||
#if MGB_ENABLE_OPR_MM | |||||
optimizer.add_pass<gopt::PackAllReduceScanPass>(); | |||||
#endif | |||||
optimizer.add_preset_passes(false, nullptr, &options()); | optimizer.add_preset_passes(false, nullptr, &options()); | ||||
} | |||||
optimizer.add_pass<gopt::ExpandVirtualGradPass>(); | optimizer.add_pass<gopt::ExpandVirtualGradPass>(); | ||||
} | } | ||||
if (need_opt) | |||||
if (need_opt) { | |||||
optimizer.add_preset_passes(true, nullptr, &options()); | optimizer.add_preset_passes(true, nullptr, &options()); | ||||
#if MGB_ENABLE_OPR_MM | |||||
if (sopr_stat.has_virtual_grad) { | |||||
optimizer.add_pass<gopt::PackAllReduceReplacePass>(); | |||||
} | |||||
#endif | |||||
} | |||||
optimizer.apply_inplace(dest_vars); | optimizer.apply_inplace(dest_vars); | ||||
} | } | ||||
#endif | #endif | ||||
@@ -328,6 +328,18 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, | |||||
int16_t graph_opt_level = 2; | int16_t graph_opt_level = 2; | ||||
/*! | /*! | ||||
* max size of allreduce packs in MB | |||||
* set this option to zero to disable PackAllReducePass | |||||
*/ | |||||
int16_t allreduce_pack_max_size = 0; | |||||
/*! | |||||
* do not pack the first n allreduces | |||||
* PackAllReducePass disabled if allreduce_pack_max_size is zero | |||||
*/ | |||||
int16_t allreduce_pack_ignore_first = 2; | |||||
/*! | |||||
* set logging level, larger number means more verbose | * set logging level, larger number means more verbose | ||||
* 0: no log info | * 0: no log info | ||||
* 1: static memory allocation status | * 1: static memory allocation status | ||||
@@ -183,7 +183,6 @@ SymbolVarArray replace_oprs( | |||||
SymbolVarArray replace_vars_comp_graph( | SymbolVarArray replace_vars_comp_graph( | ||||
const SymbolVarArray &dest, ComputingGraph* new_graph); | const SymbolVarArray &dest, ComputingGraph* new_graph); | ||||
SymbolVarArray find_h2d(const SymbolVarArray& dest); | SymbolVarArray find_h2d(const SymbolVarArray& dest); | ||||
/*! | /*! | ||||
@@ -17,6 +17,7 @@ | |||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
#include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
#include "../../core/impl/graph/cg_impl.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace gopt; | using namespace gopt; | ||||
@@ -657,4 +658,309 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { | |||||
rewriter.apply_inplace(); | rewriter.apply_inplace(); | ||||
} | } | ||||
#if MGB_ENABLE_OPR_MM | |||||
#include "megbrain/opr/collective_comm.h" | |||||
/* ======================= PackAllReduceScanPass ====================== */ | |||||
const char* PackAllReduceScanPass::name() const { | |||||
return "pack_allreduce_scan"; | |||||
} | |||||
void PackAllReduceScanPass::apply(OptState& opt) const { | |||||
auto comp_graph = opt.graph().comp_graph(); | |||||
if (comp_graph->options().allreduce_pack_max_size == 0) return; | |||||
auto cb_scan = [this] (OperatorNodeBase* opr) { | |||||
if (check_pattern(opr)) { | |||||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
VarNode* target = comm.input(0)->owner_opr()->input(0); | |||||
// only pack allreduces of grads of the same target | |||||
// in case two allreduces depend on each other | |||||
size_t id = target->id(); | |||||
uint64_t hash = XXHash().update(&id, sizeof(size_t)).digest(); | |||||
comm.set_pack_hash(hash); | |||||
} | |||||
}; | |||||
opt.graph().iter(cb_scan); | |||||
} | |||||
bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||||
if (!opr->same_type<opr::CollectiveComm>()) return false; | |||||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false; | |||||
if (comm.input().size() != 1) return false; | |||||
auto grad = comm.input(0)->owner_opr(); | |||||
if (!grad->same_type<opr::VirtualGrad>()) return false; | |||||
if (grad->input().size() != 2 or grad->output().size() != 1) return false; | |||||
auto param = grad->input(1)->owner_opr(); | |||||
if (!param->same_type<opr::SharedDeviceTensor>() and | |||||
!param->same_type<opr::VolatileSharedDeviceTensor>()) return false; | |||||
if (param->input().size() != 0) return false; | |||||
return true; | |||||
} | |||||
/* ======================= PackAllReduceReplacePass ====================== */ | |||||
const char* PackAllReduceReplacePass::name() const { | |||||
return "pack_allreduce_replace"; | |||||
} | |||||
class PackAllReduceReplacePass::GroupInfo { | |||||
public: | |||||
GroupInfo(int _device, DType _dtype, | |||||
size_t _nr_devices, bool _is_root, int _rank, | |||||
std::shared_ptr<opr::GroupClient> _group_client, | |||||
const std::string& _backend); | |||||
uint64_t hash(uint64_t extra) const; | |||||
int device; | |||||
DType dtype; | |||||
size_t nr_devices; | |||||
bool is_root; | |||||
int rank; | |||||
std::shared_ptr<opr::GroupClient> group_client; | |||||
std::string backend; | |||||
}; | |||||
PackAllReduceReplacePass::GroupInfo::GroupInfo( | |||||
int _device, DType _dtype, | |||||
size_t _nr_devices, bool _is_root, int _rank, | |||||
std::shared_ptr<opr::GroupClient> _group_client, | |||||
const std::string& _backend) : | |||||
device(_device), dtype(_dtype), | |||||
nr_devices(_nr_devices), is_root(_is_root), rank(_rank), | |||||
group_client(_group_client), backend(_backend) { | |||||
} | |||||
uint64_t PackAllReduceReplacePass::GroupInfo::hash(uint64_t extra) const { | |||||
DTypeEnum ev = dtype.enumv(); | |||||
const std::string& server_addr = group_client->get_addr(); | |||||
return XXHash() | |||||
.update(&extra, sizeof(uint64_t)) | |||||
.update(&device, sizeof(int)) | |||||
.update(&ev, sizeof(DTypeEnum)) | |||||
.update(&nr_devices, sizeof(size_t)) | |||||
.update(&is_root, sizeof(bool)) | |||||
.update(&rank, sizeof(int)) | |||||
.update(server_addr.c_str(), server_addr.size()) | |||||
.update(backend.c_str(), backend.size()) | |||||
.digest(); | |||||
} | |||||
uint64_t PackAllReduceReplacePass::collect_groups(OperatorNodeBase* opr, | |||||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||||
ThinHashMap<uint64_t, cg::OprNodeArray>& groups) { | |||||
// check CollectiveComm oprs that have been marked in PackAllReduceScanPass | |||||
if (!opr->same_type<opr::CollectiveComm>()) return 0; | |||||
opr::CollectiveComm& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
if (comm.pack_hash() == 0) return 0; // pack_hash not set | |||||
VarNode* var = comm.input(0); | |||||
auto info = std::make_shared<GroupInfo>( | |||||
var->comp_node().locator().device, | |||||
var->dtype(), | |||||
comm.nr_devices(), | |||||
comm.is_root(), | |||||
comm.rank(), | |||||
comm.group_client(), | |||||
comm.backend() | |||||
); | |||||
uint64_t hash = info->hash(comm.pack_hash()); | |||||
if (group_info.find(hash) == group_info.end()) { | |||||
group_info.emplace(hash, info); | |||||
} | |||||
groups[hash].push_back(opr); | |||||
return hash; | |||||
} | |||||
void PackAllReduceReplacePass::divide_packs( | |||||
const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||||
size_t max_size) { | |||||
cg::OprNodeArray pack; | |||||
size_t sum = 0; | |||||
for (auto it : groups) { | |||||
uint64_t hash = it.first; | |||||
const cg::OprNodeArray& group = it.second; | |||||
for (size_t i = 0; i < group.size(); i++) { | |||||
OperatorNodeBase* opr = group[i]; | |||||
VarNode* var = opr->input(0); | |||||
const TensorShape* shape = var->owner_graph() | |||||
->static_infer_manager().infer_shape_fallible(var); | |||||
if (shape == nullptr) continue; | |||||
pack.push_back(opr); | |||||
sum += var->dtype().size(shape->total_nr_elems()); | |||||
if (sum >= max_size) { | |||||
if (pack.size() > 1) packs[hash].push_back(pack); | |||||
pack.clear(); | |||||
sum = 0; | |||||
} | |||||
} | |||||
if (pack.size() > 1) packs[hash].push_back(pack); | |||||
pack.clear(); | |||||
sum = 0; | |||||
} | |||||
} | |||||
void PackAllReduceReplacePass::insert_packed_oprs( | |||||
size_t pack_id, | |||||
const cg::OprNodeArray& pack, | |||||
std::shared_ptr<GroupInfo> info, | |||||
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) { | |||||
// set priority | |||||
mgb_assert(pack.size() > 0); | |||||
auto graph = pack[0]->owner_graph(); | |||||
auto on_opr_inserted = [priority] (const cg::event::OprInserted& event) { | |||||
event.opr->node_prop().attribute().priority = priority; | |||||
}; | |||||
auto handler = graph->event().register_receiver<cg::event::OprInserted>(on_opr_inserted); | |||||
// flatten inputs and record shapes and partition | |||||
std::vector<SymbolVar> shapes; | |||||
SymbolVarArray flattens; | |||||
SymbolVarArray partition; | |||||
for (size_t i = 0; i < pack.size(); i++) { | |||||
VarNode* var = pack[i]->input(0); | |||||
auto shape = opr::GetVarShape::make(SymbolVar(var)); | |||||
shapes.push_back(shape); | |||||
SymbolVar flatten = SymbolVar(var).flatten(); | |||||
flattens.push_back(flatten); | |||||
partition.push_back(opr::Reduce::make(shape, {opr::Reduce::Mode::PRODUCT, 0})); | |||||
} | |||||
// concat | |||||
SymbolVar concat = opr::Concat::make(flattens, 0); | |||||
// allreduce | |||||
std::string key = ssprintf("grad_pack_%zu", pack_id); | |||||
auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||||
SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph, | |||||
key, info->nr_devices, info->is_root, info->rank, | |||||
info->group_client, param, info->dtype, info->backend)[0]; | |||||
// split according to recorded partition | |||||
SymbolVarArray splits = opr::Split::make(allreduce, | |||||
opr::Split::Options::make_partition(0, partition)); | |||||
// reshape and insert results into replace_map | |||||
mgb_assert(pack.size() == splits.size()); | |||||
for (size_t i = 0; i < pack.size(); i++) { | |||||
VarNode* reshape = splits[i].reshape(shapes[i]).node(); | |||||
replace_map[pack[i]->output(0)] = reshape; | |||||
} | |||||
} | |||||
void PackAllReduceReplacePass::apply(OptState& opt) const { | |||||
// get graph options | |||||
auto comp_graph = opt.graph().comp_graph(); | |||||
size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024; | |||||
size_t ignore_first = comp_graph->options().allreduce_pack_ignore_first; | |||||
if (max_size == 0) return; | |||||
// get topo order | |||||
auto& topo_sorter = static_cast<cg::ComputingGraphImpl*>(comp_graph)->topo_sorter(); | |||||
cg::CompSeqExtraInfo extra_info; | |||||
VarNodeArray endpoints = to_var_node_array(opt.graph().endpoint_vars()); | |||||
const cg::OprNodeArray* seq = topo_sorter.get_comp_seq(extra_info, endpoints); | |||||
topo_sorter.restore_opr_prop(); | |||||
// collect allreduce groups from topo sequence | |||||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||||
for (size_t i = 0; i < seq->size(); i++) { | |||||
if (seq->at(i)->same_type<opr::CollectiveComm>()) { | |||||
// ignore the first several allreduces | |||||
if (ignore_first > 0) { | |||||
--ignore_first; | |||||
} else { | |||||
collect_groups(seq->at(i), group_info, groups); | |||||
} | |||||
} | |||||
} | |||||
// divide groups into packs | |||||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs; | |||||
divide_packs(groups, packs, max_size); | |||||
// make sure that oprs inserted in this pass (reshape, concat, allreduce, | |||||
// split, reshape) have higher priority than existing operators | |||||
int priority = -seq->size() - 100; | |||||
// insert packed operators and generate replace_map | |||||
ThinHashMap<VarNode*, VarNode*> replace_map; | |||||
size_t pack_id = 0; | |||||
for (auto it : packs) { | |||||
uint64_t hash = it.first; | |||||
for (auto pack : it.second) { | |||||
opt.call_with_opr(pack[0], [&]() { | |||||
insert_packed_oprs(pack_id, pack, group_info[hash], replace_map, priority); | |||||
}, OprPropertyFlag::NONE); | |||||
pack_id += 1; | |||||
} | |||||
} | |||||
// replace vars | |||||
auto rewriter = opt.graph().make_rewriter(); | |||||
auto cb_replace = [&](OperatorNodeBase* opr) { | |||||
for (auto i : opr->input()) { | |||||
auto iter = replace_map.find(i); | |||||
if (iter != replace_map.end()) { | |||||
rewriter.replace_var(i, iter->second, nullptr); | |||||
} | |||||
} | |||||
rewriter.auto_replace_outputs(opr); | |||||
}; | |||||
opt.graph().iter(cb_replace); | |||||
rewriter.apply_inplace(); | |||||
} | |||||
#else | |||||
/* ======================= PackAllReduceScanPass ====================== */ | |||||
const char* PackAllReduceScanPass::name() const { | |||||
return "pack_allreduce_scan"; | |||||
} | |||||
void PackAllReduceScanPass::apply(OptState& opt) const { | |||||
} | |||||
bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||||
return true; | |||||
} | |||||
/* ======================= PackAllReduceReplacePass ====================== */ | |||||
const char* PackAllReduceReplacePass::name() const { | |||||
return "pack_allreduce_replace"; | |||||
} | |||||
void PackAllReduceReplacePass::apply(OptState& opt) const {} | |||||
uint64_t PackAllReduceReplacePass::collect_groups( | |||||
OperatorNodeBase* opr, | |||||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||||
ThinHashMap<uint64_t, cg::OprNodeArray>& groups) { | |||||
return 0; | |||||
} | |||||
void PackAllReduceReplacePass::divide_packs( | |||||
const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||||
size_t max_size) { | |||||
} | |||||
void PackAllReduceReplacePass::insert_packed_oprs( | |||||
size_t pack_id, | |||||
const cg::OprNodeArray& pack, | |||||
std::shared_ptr<GroupInfo> info, | |||||
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) { | |||||
} | |||||
#endif // MGB_ENABLE_OPR_MM | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -11,6 +11,8 @@ | |||||
#pragma once | #pragma once | ||||
#include <vector> | |||||
#include "megbrain/gopt/framework.h" | #include "megbrain/gopt/framework.h" | ||||
namespace mgb { | namespace mgb { | ||||
@@ -90,6 +92,45 @@ namespace gopt { | |||||
void apply(OptState& opt) const override; | void apply(OptState& opt) const override; | ||||
}; | }; | ||||
//! scan allreduces of param grads | |||||
class PackAllReduceScanPass final : public Pass { | |||||
public: | |||||
const char* name() const override; | |||||
void apply(OptState& opt) const override; | |||||
private: | |||||
// check pattern param -> grad -> allreduce | |||||
static bool check_pattern(OperatorNodeBase* opr); | |||||
}; | |||||
//! pack allreduces of param grads | |||||
class PackAllReduceReplacePass final : public Pass { | |||||
public: | |||||
class GroupInfo; | |||||
const char* name() const override; | |||||
void apply(OptState& opt) const override; | |||||
// collect allreduces and divide into groups | |||||
static uint64_t collect_groups( | |||||
OperatorNodeBase* opr, | |||||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||||
ThinHashMap<uint64_t, cg::OprNodeArray>& groups); | |||||
// divide groups into packs, max_size in MB | |||||
static void divide_packs( | |||||
const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||||
size_t max_size); | |||||
// insert packed operators and update replace_map | |||||
static void insert_packed_oprs( | |||||
size_t pack_id, | |||||
const cg::OprNodeArray& pack, | |||||
std::shared_ptr<GroupInfo> info, | |||||
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority); | |||||
}; | |||||
} // namespace gopt | } // namespace gopt | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -14,6 +14,7 @@ | |||||
#include "megbrain/gopt/basic_arith.h" | #include "megbrain/gopt/basic_arith.h" | ||||
#include "megbrain/gopt/misc.h" | #include "megbrain/gopt/misc.h" | ||||
#include "megbrain/opr/basic_arith_wrapper.h" | #include "megbrain/opr/basic_arith_wrapper.h" | ||||
#include "megbrain/opr/blas.h" | |||||
#include "megbrain/opr/cond.h" | #include "megbrain/opr/cond.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
@@ -410,4 +411,322 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) { | |||||
check(x_q8_q8, x_q8_fp32_q8_); | check(x_q8_q8, x_q8_fp32_q8_); | ||||
} | } | ||||
#if MGB_ENABLE_OPR_MM | |||||
#include "megbrain/opr/collective_comm.h" | |||||
#include "../../opr-mm/test/mock_client.h" | |||||
TEST_PASS(PackAllReduceScanPass, Basic) { | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().allreduce_pack_max_size = 5000; | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto cn = CompNode::load("gpux"); | |||||
auto dev_x0 = std::make_shared<DeviceTensorND>(cn, TensorShape{3, 5}); | |||||
auto dev_x1 = std::make_shared<DeviceTensorND>(cn, TensorShape{4, 6}); | |||||
auto dev_y0 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}); | |||||
auto dev_y1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}); | |||||
auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0); | |||||
auto x1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_x1); | |||||
auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0); | |||||
auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1); | |||||
auto grad0 = opr::VirtualGrad::make(y0, x0); | |||||
auto grad1 = opr::VirtualGrad::make(y0, x1); | |||||
auto grad2 = opr::VirtualGrad::make(y1, x0); | |||||
auto grad3 = opr::VirtualGrad::make(y1, x1); | |||||
auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||||
auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), | |||||
"grad0", 2, 0, 0, client, mode)[0]; | |||||
auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), | |||||
"grad1", 2, 0, 0, client, mode)[0]; | |||||
auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), | |||||
"grad2", 2, 0, 0, client, mode)[0]; | |||||
auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), | |||||
"grad3", 2, 0, 0, client, mode)[0]; | |||||
gopt::GraphOptimizer() | |||||
.add_pass<gopt::PackAllReduceScanPass>() | |||||
.apply({{comm0, comm1, comm2, comm3}}); | |||||
auto get_hash = [] (const SymbolVar& symvar) { | |||||
cg::OperatorNodeBase* opr = symvar.node()->owner_opr(); | |||||
return opr->cast_final_safe<opr::CollectiveComm>().pack_hash(); | |||||
}; | |||||
uint64_t hash0 = get_hash(comm0); | |||||
uint64_t hash1 = get_hash(comm1); | |||||
uint64_t hash2 = get_hash(comm2); | |||||
uint64_t hash3 = get_hash(comm3); | |||||
ASSERT_EQ(hash0, hash1); | |||||
ASSERT_EQ(hash2, hash3); | |||||
ASSERT_NE(hash0, hash2); | |||||
} | |||||
TEST_PASS(PackAllReduceReplacePass, CollectGroups) { | |||||
REQUIRE_GPU(2); | |||||
auto cns = load_multiple_xpus(2); | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 2; | |||||
auto cli0 = std::make_shared<test::MockGroupClient>("mock_addr0"); | |||||
auto cli1 = std::make_shared<test::MockGroupClient>("mock_addr1"); | |||||
using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||||
auto add_opr = [&] (const CompNode& cn, TensorShape shape, const DType& dt, | |||||
std::shared_ptr<test::MockGroupClient> client, uint64_t extra_hash) { | |||||
auto dev0 = std::make_shared<DeviceTensorND>(cn, shape, dt); | |||||
auto wrt = opr::SharedDeviceTensor::make(*graph, dev0); | |||||
auto dev1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}, dt); | |||||
auto target = opr::SharedDeviceTensor::make(*graph, dev1); | |||||
auto grad = opr::VirtualGrad::make(target, wrt); | |||||
auto comm = opr::CollectiveComm::make( | |||||
{grad}, graph.get(), "key", 2, 0, 0, client, | |||||
opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] | |||||
.node()->owner_opr(); | |||||
comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash); | |||||
return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info, groups); | |||||
}; | |||||
uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1); | |||||
uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1); // same | |||||
uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1); // comp_node | |||||
uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1); // dtype | |||||
uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1); // client | |||||
uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2); // extra_hash | |||||
ASSERT_EQ(hash0, hash1); | |||||
std::set<uint64_t> s; | |||||
s.insert(hash0); | |||||
s.insert(hash1); | |||||
s.insert(hash2); | |||||
s.insert(hash3); | |||||
s.insert(hash4); | |||||
s.insert(hash5); | |||||
ASSERT_EQ(5, s.size()); | |||||
ASSERT_EQ(1, group_info.count(hash0)); | |||||
ASSERT_EQ(1, group_info.count(hash1)); | |||||
ASSERT_EQ(1, group_info.count(hash2)); | |||||
ASSERT_EQ(1, group_info.count(hash3)); | |||||
ASSERT_EQ(1, group_info.count(hash4)); | |||||
ASSERT_EQ(1, group_info.count(hash5)); | |||||
ASSERT_EQ(2, groups[hash0].size()); | |||||
ASSERT_EQ(2, groups[hash1].size()); | |||||
ASSERT_EQ(1, groups[hash2].size()); | |||||
ASSERT_EQ(1, groups[hash3].size()); | |||||
ASSERT_EQ(1, groups[hash4].size()); | |||||
ASSERT_EQ(1, groups[hash5].size()); | |||||
} | |||||
TEST_PASS(PackAllReduceReplacePass, DividePacks) { | |||||
auto cn = CompNode::load("gpux"); | |||||
auto graph = ComputingGraph::make(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs; | |||||
auto insert_opr = [&] (size_t size) { | |||||
auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)}); | |||||
auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||||
auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||||
"key", 2, 0, 0, client, mode)[0]; | |||||
auto opr = symvar.node()->owner_opr(); | |||||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
comm.set_pack_hash(1); | |||||
return opr; | |||||
}; | |||||
auto pack_size = [&] (cg::OprNodeArray& pack) { | |||||
size_t sum = 0; | |||||
for (size_t i = 0; i < pack.size(); i++) { | |||||
auto var = pack[i]->input(0); | |||||
sum += var->dtype().size(var->shape().total_nr_elems()); | |||||
} | |||||
return sum; | |||||
}; | |||||
groups[0].push_back(insert_opr(100)); // group0, pack0, size=1100 | |||||
groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 | |||||
groups[0].push_back(insert_opr(400)); // group0, pack0, size=1100 | |||||
groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 | |||||
groups[0].push_back(insert_opr(500)); // group0, pack1, size=800 | |||||
groups[0].push_back(insert_opr(200)); // group0, pack1, size=800 | |||||
groups[0].push_back(insert_opr(100)); // group0, pack1, size=800 | |||||
groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 | |||||
groups[1].push_back(insert_opr(400)); // group1, pack0, size=900 | |||||
groups[1].push_back(insert_opr(300)); // group1, pack0, size=900 | |||||
groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 | |||||
gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000); | |||||
ASSERT_EQ(2, packs.size()); | |||||
ASSERT_EQ(2, packs[0].size()); | |||||
ASSERT_EQ(4, packs[0][0].size()); | |||||
ASSERT_EQ(1100, pack_size(packs[0][0])); | |||||
ASSERT_EQ(3, packs[0][1].size()); | |||||
ASSERT_EQ(800, pack_size(packs[0][1])); | |||||
ASSERT_EQ(1, packs[1].size()); | |||||
ASSERT_EQ(4, packs[1][0].size()); | |||||
ASSERT_EQ(900, pack_size(packs[1][0])); | |||||
} | |||||
TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { | |||||
auto cn = CompNode::load("gpux"); | |||||
auto graph = ComputingGraph::make(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||||
size_t nr_devices = 2; | |||||
uint32_t rank = 0; | |||||
uint32_t root = 0; | |||||
using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||||
auto insert_opr = [&] (const TensorShape& shape) { | |||||
auto dev = std::make_shared<DeviceTensorND>(cn, shape); | |||||
auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||||
auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||||
"key", nr_devices, rank, root, client, mode)[0]; | |||||
auto opr = symvar.node()->owner_opr(); | |||||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
comm.set_pack_hash(1); | |||||
gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups); | |||||
return symvar; | |||||
}; | |||||
auto shape_x = TensorShape{100, 200}; | |||||
auto shape_y = TensorShape{200, 400}; | |||||
auto x = insert_opr(shape_x); | |||||
auto y = insert_opr(shape_y); | |||||
ASSERT_EQ(1, group_info.size()); | |||||
ASSERT_EQ(1, groups.size()); | |||||
auto info = group_info.begin()->second; | |||||
auto pack = groups.begin()->second; | |||||
size_t pack_id = 0; | |||||
ThinHashMap<VarNode*, VarNode*> replace_map; | |||||
gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info, replace_map, -1); | |||||
auto grad_x = SymbolVar(x.node()->owner_opr()->input(0)); | |||||
auto grad_y = SymbolVar(y.node()->owner_opr()->input(0)); | |||||
auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0); | |||||
std::string key = ssprintf("grad_pack_%zu", pack_id); | |||||
auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), | |||||
key, nr_devices, rank, root, client, mode)[0]; | |||||
std::vector<size_t> partition; | |||||
partition.push_back(shape_x.total_nr_elems()); | |||||
partition.push_back(shape_y.total_nr_elems()); | |||||
auto splits = opr::Split::make(allreduce, | |||||
opr::Split::Options::make_partition(allreduce, 0, partition)); | |||||
ASSERT_EQ(2, splits.size()); | |||||
auto dest_x = splits[0].reshape(shape_x); | |||||
auto dest_y = splits[1].reshape(shape_y); | |||||
ASSERT_EQ(2, replace_map.size()); | |||||
ASSERT_TRUE(replace_map.count(x.node()) > 0); | |||||
ASSERT_EQ(replace_map.at(x.node()), dest_x.node()); | |||||
ASSERT_TRUE(replace_map.count(y.node()) > 0); | |||||
ASSERT_EQ(replace_map.at(y.node()), dest_y.node()); | |||||
} | |||||
TEST_PASS(PackAllReduceReplacePass, Equivalence) { | |||||
REQUIRE_GPU(2); | |||||
auto cns = load_multiple_xpus(2); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto build_graph = [&] (uint32_t rank, std::shared_ptr<ComputingGraph> graph, | |||||
SymbolVarArray& array) { | |||||
HostTensorGenerator<> gen; | |||||
auto cn = cns[rank]; | |||||
auto host_x = gen({1, 1000}); | |||||
auto host_y = gen({1000, 1}); | |||||
auto dev_x = std::make_shared<DeviceTensorND>(cn); | |||||
auto dev_y = std::make_shared<DeviceTensorND>(cn); | |||||
dev_x->copy_from(*host_x).sync(); | |||||
dev_y->copy_from(*host_y).sync(); | |||||
auto x = opr::SharedDeviceTensor::make(*graph, dev_x); | |||||
auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y); | |||||
auto loss = opr::MatrixMul::make(x, y).flatten(); | |||||
auto grad_x = opr::VirtualGrad::make(loss, x); | |||||
auto grad_y = opr::VirtualGrad::make(loss, y); | |||||
using Mode = opr::CollectiveComm::Param::Mode; | |||||
bool is_root = (rank == 0); | |||||
auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), | |||||
"x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||||
auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), | |||||
"y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||||
graph->options().allreduce_pack_max_size = 5000; | |||||
graph->options().allreduce_pack_ignore_first = 0; | |||||
auto dest_vars = gopt::GraphOptimizer{} | |||||
.add_pass<gopt::PackAllReduceScanPass>() | |||||
.add_pass<gopt::PackAllReduceReplacePass>() | |||||
.apply({{reduced_x, reduced_y}}).endpoint_vars(); | |||||
array.emplace_back(reduced_x); | |||||
array.emplace_back(reduced_y); | |||||
array.emplace_back(dest_vars[0]); | |||||
array.emplace_back(dest_vars[1]); | |||||
}; | |||||
auto run = [&] (uint32_t rank) { | |||||
auto graph = ComputingGraph::make(); | |||||
SymbolVarArray array; | |||||
build_graph(rank, graph, array); | |||||
HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1; | |||||
graph->options().allreduce_pack_max_size = 0; | |||||
auto func = graph->compile({make_callback_copy(array[0], host_reduced_x), | |||||
make_callback_copy(array[1], host_reduced_y), | |||||
make_callback_copy(array[2], host_dest_0), | |||||
make_callback_copy(array[3], host_dest_1)}); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0); | |||||
MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1); | |||||
}; | |||||
std::thread t0(run, 0); | |||||
std::thread t1(run, 1); | |||||
t0.join(); | |||||
t1.join(); | |||||
} | |||||
#endif // MGB_ENABLE_OPR_MM | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -461,16 +461,7 @@ void CollectiveComm::opr_register() { | |||||
m_rank = reg_info.rank; | m_rank = reg_info.rank; | ||||
m_root = reg_info.root_rank; | m_root = reg_info.root_rank; | ||||
MegRayCommunicatorBuilder* builder; | |||||
{ | |||||
static std::mutex user_data_mtx; | |||||
std::unique_lock<std::mutex> lk(user_data_mtx); | |||||
builder = owner_graph()->options().user_data | |||||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||||
} | |||||
m_megray_comm = builder->get_megray_comm( | |||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||||
reg_info.hash, m_key, m_nr_devices, m_rank, | reg_info.hash, m_key, m_nr_devices, m_rank, | ||||
get_megray_backend(m_backend), m_group_client); | get_megray_backend(m_backend), m_group_client); | ||||
@@ -736,13 +727,15 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( | |||||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | ||||
return opr::CollectiveComm::make( | |||||
auto new_opr = CollectiveComm::make( | |||||
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | ||||
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | ||||
opr.group_client(), opr.dev_buffers(), opr.param(), | opr.group_client(), opr.dev_buffers(), opr.param(), | ||||
opr.dtype(), opr.backend(), config)[0] | opr.dtype(), opr.backend(), config)[0] | ||||
.node() | .node() | ||||
->owner_opr(); | ->owner_opr(); | ||||
new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash()); | |||||
return new_opr; | |||||
} | } | ||||
MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm); | MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm); | ||||
@@ -54,13 +54,7 @@ void RemoteSend::scn_do_execute() { | |||||
auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | ||||
comp_node.get_uid()); | comp_node.get_uid()); | ||||
auto megray_comm_builder = | |||||
owner_graph() | |||||
->options() | |||||
.user_data | |||||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||||
m_megray_comm = megray_comm_builder->get_megray_comm( | |||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||||
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | ||||
m_init = true; | m_init = true; | ||||
} | } | ||||
@@ -158,13 +152,7 @@ void RemoteRecv::scn_do_execute() { | |||||
m_peer.key, 2, false, 1, | m_peer.key, 2, false, 1, | ||||
comp_node.get_uid()); | comp_node.get_uid()); | ||||
auto megray_comm_builder = | |||||
owner_graph() | |||||
->options() | |||||
.user_data | |||||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||||
m_megray_comm = megray_comm_builder->get_megray_comm( | |||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||||
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | ||||
m_init = true; | m_init = true; | ||||
} | } | ||||
@@ -14,8 +14,8 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace opr; | using namespace opr; | ||||
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||||
std::unique_lock<std::mutex> lk(m_mtx); | |||||
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||||
std::unique_lock<std::mutex> lk(m_map_mtx); | |||||
auto it = m_megray_comms.find(hash); | auto it = m_megray_comms.find(hash); | ||||
if (it != m_megray_comms.end()) { | if (it != m_megray_comms.end()) { | ||||
comm = it->second; | comm = it->second; | ||||
@@ -24,27 +24,37 @@ bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Comm | |||||
return false; | return false; | ||||
} | } | ||||
void MegRayCommunicatorBuilder::emplace(uint64_t hash, | |||||
void MegRayCommBuilder::emplace(uint64_t hash, | |||||
std::shared_ptr<MegRay::Communicator> comm) { | std::shared_ptr<MegRay::Communicator> comm) { | ||||
std::unique_lock<std::mutex> lk(m_mtx); | |||||
std::unique_lock<std::mutex> lk(m_map_mtx); | |||||
m_megray_comms.emplace(hash, comm); | m_megray_comms.emplace(hash, comm); | ||||
} | } | ||||
std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm( | |||||
std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||||
uint64_t hash, std::string key, uint32_t size, uint32_t rank, | uint64_t hash, std::string key, uint32_t size, uint32_t rank, | ||||
MegRay::Backend backend, | MegRay::Backend backend, | ||||
std::shared_ptr<mgb::opr::GroupClient> group_client) { | std::shared_ptr<mgb::opr::GroupClient> group_client) { | ||||
{ | |||||
// singleton pattern | |||||
std::unique_lock<std::mutex> lk(sm_instance_mtx); | |||||
if (sm_instance == nullptr) { | |||||
sm_instance = new MegRayCommBuilder(); | |||||
} | |||||
} | |||||
std::shared_ptr<MegRay::Communicator> comm; | std::shared_ptr<MegRay::Communicator> comm; | ||||
if (!find(hash, comm)) { | |||||
if (!sm_instance->find(hash, comm)) { | |||||
comm = MegRay::get_communicator(size, rank, backend); | comm = MegRay::get_communicator(size, rank, backend); | ||||
auto uid = comm->get_uid(); | auto uid = comm->get_uid(); | ||||
auto uids = group_client->gather_uid(uid, key, size, rank); | auto uids = group_client->gather_uid(uid, key, size, rank); | ||||
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | ||||
emplace(hash, comm); | |||||
sm_instance->emplace(hash, comm); | |||||
} | } | ||||
return comm; | return comm; | ||||
} | } | ||||
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); | |||||
MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | |||||
std::mutex MegRayCommBuilder::sm_instance_mtx; | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -81,6 +81,10 @@ public: | |||||
return m_group_client; | return m_group_client; | ||||
} | } | ||||
void set_pack_hash(uint64_t hash) { m_pack_hash = hash; } | |||||
uint64_t pack_hash() const { return m_pack_hash; } | |||||
std::shared_ptr<MegRay::Context> megray_ctx() const { | std::shared_ptr<MegRay::Context> megray_ctx() const { | ||||
return m_megray_ctx; | return m_megray_ctx; | ||||
} | } | ||||
@@ -123,6 +127,9 @@ private: | |||||
// whose shape infer should be disabled *during* static infer phase. | // whose shape infer should be disabled *during* static infer phase. | ||||
bool m_enable_shape_infer = false; | bool m_enable_shape_infer = false; | ||||
//! set in PackAllReduceScanPass and used in PackAllReduceReplacePass | |||||
uint64_t m_pack_hash = 0; | |||||
std::shared_ptr<MegRay::Context> m_megray_ctx; | std::shared_ptr<MegRay::Context> m_megray_ctx; | ||||
std::shared_ptr<MegRay::Communicator> m_megray_comm; | std::shared_ptr<MegRay::Communicator> m_megray_comm; | ||||
bool m_init = false; | bool m_init = false; | ||||
@@ -126,6 +126,8 @@ class GroupClient { | |||||
virtual ~GroupClient() = default; | virtual ~GroupClient() = default; | ||||
public: | public: | ||||
virtual const std::string& get_addr() const = 0; | |||||
virtual GroupManager::RegisterInfo opr_register(const std::string& key, | virtual GroupManager::RegisterInfo opr_register(const std::string& key, | ||||
size_t nr_devices, | size_t nr_devices, | ||||
bool is_root, int rank, | bool is_root, int rank, | ||||
@@ -23,18 +23,19 @@ namespace opr { | |||||
/*! | /*! | ||||
* gather MegRay unique ids and build communicator, use hash for deduplication | * gather MegRay unique ids and build communicator, use hash for deduplication | ||||
*/ | */ | ||||
class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData { | |||||
MGB_TYPEINFO_OBJ_DECL; | |||||
class MegRayCommBuilder { | |||||
private: | private: | ||||
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | ||||
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | ||||
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | ||||
std::mutex m_mtx; | |||||
std::mutex m_map_mtx; | |||||
static MegRayCommBuilder* sm_instance; | |||||
static std::mutex sm_instance_mtx; | |||||
public: | public: | ||||
std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||||
static std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||||
uint64_t hash, std::string key, uint32_t size, uint32_t rank, | uint64_t hash, std::string key, uint32_t size, uint32_t rank, | ||||
MegRay::Backend backend, | MegRay::Backend backend, | ||||
std::shared_ptr<mgb::opr::GroupClient> group_client); | std::shared_ptr<mgb::opr::GroupClient> group_client); | ||||
@@ -47,7 +47,7 @@ public: | |||||
uint32_t group_barrier(uint32_t size, uint32_t rank) override; | uint32_t group_barrier(uint32_t size, uint32_t rank) override; | ||||
const std::string& get_addr() const { | |||||
const std::string& get_addr() const override { | |||||
return m_addr; | return m_addr; | ||||
} | } | ||||
@@ -17,11 +17,10 @@ | |||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "mock_client.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
namespace { | |||||
using Mode = opr::CollectiveComm::Param::Mode; | using Mode = opr::CollectiveComm::Param::Mode; | ||||
SymbolVar make_all_reduce_output(const Mode mode, | SymbolVar make_all_reduce_output(const Mode mode, | ||||
@@ -41,41 +40,6 @@ SymbolVarArray make_reduce_scatter_sum_output(const SymbolVarArray& inputs) { | |||||
rdc, opr::Split::Options::make_average(0, inputs.size())); | rdc, opr::Split::Options::make_average(0, inputs.size())); | ||||
} | } | ||||
class MockGroupClient final : public opr::GroupClient { | |||||
public: | |||||
~MockGroupClient() override = default; | |||||
opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||||
size_t nr_devices, | |||||
bool is_root, int rank, | |||||
uintptr_t stream) { | |||||
return m_mgr.opr_register(key, nr_devices, is_root, rank, stream); | |||||
} | |||||
std::vector<std::string> gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) { | |||||
return m_mgr.gather_uid(uid, key, size, rank); | |||||
} | |||||
void set_output_shape(const std::string& key, | |||||
const TensorShape& shape) override { | |||||
m_mgr.set_output_shape(key, shape); | |||||
} | |||||
TensorShape get_output_shape(const std::string& key) override { | |||||
return m_mgr.get_output_shape(key); | |||||
} | |||||
uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||||
return m_mgr.group_barrier(size, rank); | |||||
} | |||||
private: | |||||
opr::GroupManager m_mgr; | |||||
}; | |||||
} // namespace | |||||
TEST(TestOprCollectiveComm, AllReduce) { | TEST(TestOprCollectiveComm, AllReduce) { | ||||
REQUIRE_GPU(2); | REQUIRE_GPU(2); | ||||
@@ -88,7 +52,7 @@ TEST(TestOprCollectiveComm, AllReduce) { | |||||
auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
@@ -126,7 +90,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { | |||||
auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { | auto run_0 = [&]() { | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -187,7 +151,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { | |||||
HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -268,7 +232,7 @@ TEST(TestOprCollectiveComm, AllGather) { | |||||
auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
@@ -300,7 +264,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { | |||||
auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -356,7 +320,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { | |||||
HostTensorND host_out_grad0, host_out_grad1; | HostTensorND host_out_grad0, host_out_grad1; | ||||
HostTensorND host_out_grad0_expect, host_out_grad1_expect; | HostTensorND host_out_grad0_expect, host_out_grad1_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -438,7 +402,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { | |||||
auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
@@ -471,7 +435,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||||
auto host_x1 = gen({8}); | auto host_x1 = gen({8}); | ||||
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -528,7 +492,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||||
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | ||||
HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -610,7 +574,7 @@ TEST(TestOprCollectiveComm, ReduceSum) { | |||||
auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
@@ -641,7 +605,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||||
auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y_expect; | HostTensorND host_y0, host_y_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -694,7 +658,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||||
HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1; | HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -764,7 +728,7 @@ TEST(TestOprCollectiveComm, Broadcast) { | |||||
auto host_x0 = gen({28, 28}); | auto host_x0 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
@@ -794,7 +758,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||||
auto host_x0 = gen({28, 28}); | auto host_x0 = gen({28, 28}); | ||||
HostTensorND host_y0, host_y1; | HostTensorND host_y0, host_y1; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -840,7 +804,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||||
HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect; | HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
@@ -14,51 +14,14 @@ | |||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "megbrain/system.h" | #include "megbrain/system.h" | ||||
#include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
#include "mock_client.h" | |||||
#include <thread> | #include <thread> | ||||
using namespace mgb; | using namespace mgb; | ||||
using namespace opr; | |||||
namespace { | |||||
class MockGroupClient final : public opr::GroupClient { | |||||
public: | |||||
~MockGroupClient() override = default; | |||||
opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||||
size_t nr_devices, | |||||
bool is_root, int rank, | |||||
uint64_t comp_node_hash) { | |||||
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||||
} | |||||
std::vector<std::string> gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) { | |||||
return m_mgr.gather_uid(uid, key, size, rank); | |||||
} | |||||
void set_output_shape(const std::string& key, | |||||
const TensorShape& shape) override { | |||||
m_mgr.set_output_shape(key, shape); | |||||
} | |||||
TensorShape get_output_shape(const std::string& key) override { | |||||
return m_mgr.get_output_shape(key); | |||||
} | |||||
uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||||
return m_mgr.group_barrier(size, rank); | |||||
} | |||||
private: | |||||
opr::GroupManager m_mgr; | |||||
}; | |||||
const auto send_tag = RemoteIOBase::Type::SEND; | |||||
const auto recv_tag = RemoteIOBase::Type::RECV; | |||||
} // anonymous namespace | |||||
const auto send_tag = opr::RemoteIOBase::Type::SEND; | |||||
const auto recv_tag = opr::RemoteIOBase::Type::RECV; | |||||
TEST(TestOprIORemote, Identity) { | TEST(TestOprIORemote, Identity) { | ||||
REQUIRE_GPU(2); | REQUIRE_GPU(2); | ||||
@@ -69,7 +32,7 @@ TEST(TestOprIORemote, Identity) { | |||||
auto host_x = gen({28, 28}); | auto host_x = gen({28, 28}); | ||||
HostTensorND host_y; | HostTensorND host_y; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | ||||
@@ -90,7 +53,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto host_x = gen({2, 3}, cns[1]); | auto host_x = gen({2, 3}, cns[1]); | ||||
HostTensorND host_x_get; | HostTensorND host_x_get; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto sender = [&]() { | auto sender = [&]() { | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
@@ -123,7 +86,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto host_x = gen({2, 3}, cns[0]); | auto host_x = gen({2, 3}, cns[0]); | ||||
HostTensorND host_x_get; | HostTensorND host_x_get; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto sender = [&]() { | auto sender = [&]() { | ||||
sys::set_thread_name("sender"); | sys::set_thread_name("sender"); | ||||
@@ -157,7 +120,7 @@ TEST(TestOprIORemote, APlusB) { | |||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, cns[0]); | auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, cns[0]); | ||||
HostTensorND host_z; | HostTensorND host_z; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto sender = [&]() { | auto sender = [&]() { | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
@@ -208,7 +171,7 @@ TEST(TestOprIORemote, SendGrad) { | |||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto host_x = gen({2, 3}, cns[0]); | auto host_x = gen({2, 3}, cns[0]); | ||||
HostTensorND host_gx, host_loss; | HostTensorND host_gx, host_loss; | ||||
auto client = std::make_shared<MockGroupClient>(); | |||||
auto client = std::make_shared<test::MockGroupClient>(); | |||||
auto sender = [&]() { | auto sender = [&]() { | ||||
sys::set_thread_name("sender"); | sys::set_thread_name("sender"); | ||||
@@ -0,0 +1,62 @@ | |||||
/** | |||||
* \file src/opr-mm/test/mock_client.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/opr/group_manager.h" | |||||
namespace mgb { | |||||
namespace test { | |||||
class MockGroupClient final : public opr::GroupClient { | |||||
public: | |||||
using RegisterInfo = opr::GroupManager::RegisterInfo; | |||||
MockGroupClient(const std::string& server_addr = "mock_addr") : | |||||
m_addr(server_addr) { | |||||
} | |||||
~MockGroupClient() override = default; | |||||
const std::string& get_addr() const { | |||||
return m_addr; | |||||
} | |||||
RegisterInfo opr_register(const std::string& key, size_t nr_devices, | |||||
bool is_root, int rank, uint64_t comp_node_hash) { | |||||
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||||
} | |||||
std::vector<std::string> gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) { | |||||
return m_mgr.gather_uid(uid, key, size, rank); | |||||
} | |||||
void set_output_shape(const std::string& key, | |||||
const TensorShape& shape) override { | |||||
m_mgr.set_output_shape(key, shape); | |||||
} | |||||
TensorShape get_output_shape(const std::string& key) override { | |||||
return m_mgr.get_output_shape(key); | |||||
} | |||||
uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||||
return m_mgr.group_barrier(size, rank); | |||||
} | |||||
private: | |||||
const std::string m_addr; | |||||
opr::GroupManager m_mgr; | |||||
}; | |||||
} // namespace test | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |