@@ -260,3 +260,15 @@ def replace_oprs(dst, oprmap): | |||
repl_dst_vec.push_back(j) | |||
return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | |||
def set_priority_to_id(dest_vars): | |||
"""For all oprs in the subgraph constructed by dest_vars | |||
set its priority to id if its original priority is zero | |||
:param dest_vars: target vars representing the graph | |||
""" | |||
dest_vec = _mgb._VectorSymbolVar() | |||
for i in dest_vars: | |||
assert isinstance(i, _mgb.SymbolVar) | |||
dest_vec.push_back(i) | |||
_mgb._set_priority_to_id(dest_vec) |
@@ -84,6 +84,8 @@ class trace: | |||
:param log_level: Log level. | |||
:param sublinear_memory_config: Configuration for sublinear memory optimization. | |||
If not None, it enables sublinear memory optimization with given setting. | |||
:param allreduce_pack_max_size: Maximum size of an allreduce pack in MB. | |||
If not None, multiple gradients will be packed and synchronized together | |||
:param profiling: Whether to profile compiled trace. Default: False | |||
""" | |||
@@ -107,6 +109,7 @@ class trace: | |||
opt_level: int = None, | |||
log_level: int = None, | |||
sublinear_memory_config: SublinearMemoryConfig = None, | |||
allreduce_pack_max_size: int = None, | |||
profiling: bool = False | |||
): | |||
self.__wrapped__ = func | |||
@@ -114,6 +117,7 @@ class trace: | |||
self._graph_opt_level = opt_level | |||
self._log_level = log_level | |||
self._sublinear_memory_config = sublinear_memory_config | |||
self._allreduce_pack_max_size = allreduce_pack_max_size | |||
self._status = self._UNSTARTED | |||
self._args = None | |||
self._kwargs = None | |||
@@ -313,6 +317,9 @@ class trace: | |||
"sublinear_mem_cofig.num_worker", | |||
self._sublinear_memory_config.num_worker, | |||
) | |||
# pack allreduce | |||
if self._allreduce_pack_max_size is not None: | |||
cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size) | |||
# profile | |||
if self._profiling: | |||
self._profiler = CompGraphProfiler(cg) | |||
@@ -391,6 +398,7 @@ class trace: | |||
outputs = [outputs] | |||
# _run_wrapped has checked validity of outputs | |||
self._sym_outputs = tuple(i._symvar for i in outputs) | |||
mgb.comp_graph_tools.set_priority_to_id(self._outspec) | |||
self._compiled_func = graph.get_default_graph().compile(None, self._outspec) | |||
def trace(self, *args: Tensor, **kwargs): | |||
@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): | |||
:param loss: The obtained loss tensor | |||
""" | |||
rst = [] | |||
priority = 0 | |||
params = [] | |||
for group in self.param_groups: | |||
for param in group["params"]: | |||
@@ -180,14 +179,14 @@ class Optimizer(metaclass=ABCMeta): | |||
for param, grad in zip(params, grads): | |||
if is_distributed(): | |||
priority += 1 | |||
with opr_priority_scope(cg, -priority): | |||
# all_reduce_mean | |||
with opr_priority_scope(cg, -(2 ** 30)): | |||
# always run all_reduce_mean first except add_update | |||
grad = ( | |||
all_reduce_sum(grad, "grad_" + str(get_group_id())) | |||
/ get_world_size() | |||
) | |||
with opr_priority_scope(cg, (1 << 30) - priority): | |||
with opr_priority_scope(cg, -(2 ** 31)): | |||
# always run add_update first | |||
grad_update = add_update(param.grad, grad) | |||
else: | |||
grad_update = add_update(param.grad, grad) | |||
@@ -66,6 +66,8 @@ bool _config::set_comp_graph_option( | |||
SET_CG_OPTION(graph_opt.jit); | |||
SET_CG_OPTION(graph_opt.tensorrt); | |||
SET_CG_OPTION(graph_opt_level); | |||
SET_CG_OPTION(allreduce_pack_max_size); | |||
SET_CG_OPTION(allreduce_pack_ignore_first); | |||
SET_CG_OPTION(var_sanity_check_first_run); | |||
SET_CG_OPTION(no_profiling_on_shape_change); | |||
SET_CG_OPTION(allocate_static_mem_after_graph_compile); | |||
@@ -1,3 +1,7 @@ | |||
%{ | |||
#include "megbrain/gopt/framework.h" | |||
%} | |||
%inline { | |||
SymbolVarArray _get_owner_opr_inputs(SymbolVar var) { | |||
@@ -35,5 +39,17 @@ | |||
} | |||
return mgb::cg::replace_oprs(vars, oprmap); | |||
} | |||
void _set_priority_to_id(const SymbolVarArray& dest_vars) { | |||
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { | |||
if (opr->node_prop().attribute().priority == 0) { | |||
opr->node_prop().attribute().priority = opr->id(); | |||
} | |||
}; | |||
mgb::cg::DepOprIter dep_iter{on_opr}; | |||
for (const SymbolVar& var : dest_vars) { | |||
dep_iter.add(var); | |||
} | |||
} | |||
} | |||
// vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}} |
@@ -441,12 +441,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||
optimizer.verbosity(options().log_level); | |||
optimizer.enable_check_result(options().graph_opt_level < 0); | |||
if (sopr_stat.has_virtual_grad) { | |||
if (need_opt) | |||
if (need_opt) { | |||
#if MGB_ENABLE_OPR_MM | |||
optimizer.add_pass<gopt::PackAllReduceScanPass>(); | |||
#endif | |||
optimizer.add_preset_passes(false, nullptr, &options()); | |||
} | |||
optimizer.add_pass<gopt::ExpandVirtualGradPass>(); | |||
} | |||
if (need_opt) | |||
if (need_opt) { | |||
optimizer.add_preset_passes(true, nullptr, &options()); | |||
#if MGB_ENABLE_OPR_MM | |||
if (sopr_stat.has_virtual_grad) { | |||
optimizer.add_pass<gopt::PackAllReduceReplacePass>(); | |||
} | |||
#endif | |||
} | |||
optimizer.apply_inplace(dest_vars); | |||
} | |||
#endif | |||
@@ -328,6 +328,18 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, | |||
int16_t graph_opt_level = 2; | |||
/*! | |||
* max size of allreduce packs in MB | |||
* set this option to zero to disable PackAllReducePass | |||
*/ | |||
int16_t allreduce_pack_max_size = 0; | |||
/*! | |||
* do not pack the first n allreduces | |||
* PackAllReducePass disabled if allreduce_pack_max_size is zero | |||
*/ | |||
int16_t allreduce_pack_ignore_first = 2; | |||
/*! | |||
* set logging level, larger number means more verbose | |||
* 0: no log info | |||
* 1: static memory allocation status | |||
@@ -183,7 +183,6 @@ SymbolVarArray replace_oprs( | |||
SymbolVarArray replace_vars_comp_graph( | |||
const SymbolVarArray &dest, ComputingGraph* new_graph); | |||
SymbolVarArray find_h2d(const SymbolVarArray& dest); | |||
/*! | |||
@@ -17,6 +17,7 @@ | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/serialization/serializer.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "../../core/impl/graph/cg_impl.h" | |||
using namespace mgb; | |||
using namespace gopt; | |||
@@ -657,4 +658,309 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { | |||
rewriter.apply_inplace(); | |||
} | |||
#if MGB_ENABLE_OPR_MM | |||
#include "megbrain/opr/collective_comm.h" | |||
/* ======================= PackAllReduceScanPass ====================== */ | |||
const char* PackAllReduceScanPass::name() const { | |||
return "pack_allreduce_scan"; | |||
} | |||
void PackAllReduceScanPass::apply(OptState& opt) const { | |||
auto comp_graph = opt.graph().comp_graph(); | |||
if (comp_graph->options().allreduce_pack_max_size == 0) return; | |||
auto cb_scan = [this] (OperatorNodeBase* opr) { | |||
if (check_pattern(opr)) { | |||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
VarNode* target = comm.input(0)->owner_opr()->input(0); | |||
// only pack allreduces of grads of the same target | |||
// in case two allreduces depend on each other | |||
size_t id = target->id(); | |||
uint64_t hash = XXHash().update(&id, sizeof(size_t)).digest(); | |||
comm.set_pack_hash(hash); | |||
} | |||
}; | |||
opt.graph().iter(cb_scan); | |||
} | |||
bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||
if (!opr->same_type<opr::CollectiveComm>()) return false; | |||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false; | |||
if (comm.input().size() != 1) return false; | |||
auto grad = comm.input(0)->owner_opr(); | |||
if (!grad->same_type<opr::VirtualGrad>()) return false; | |||
if (grad->input().size() != 2 or grad->output().size() != 1) return false; | |||
auto param = grad->input(1)->owner_opr(); | |||
if (!param->same_type<opr::SharedDeviceTensor>() and | |||
!param->same_type<opr::VolatileSharedDeviceTensor>()) return false; | |||
if (param->input().size() != 0) return false; | |||
return true; | |||
} | |||
/* ======================= PackAllReduceReplacePass ====================== */ | |||
const char* PackAllReduceReplacePass::name() const { | |||
return "pack_allreduce_replace"; | |||
} | |||
class PackAllReduceReplacePass::GroupInfo { | |||
public: | |||
GroupInfo(int _device, DType _dtype, | |||
size_t _nr_devices, bool _is_root, int _rank, | |||
std::shared_ptr<opr::GroupClient> _group_client, | |||
const std::string& _backend); | |||
uint64_t hash(uint64_t extra) const; | |||
int device; | |||
DType dtype; | |||
size_t nr_devices; | |||
bool is_root; | |||
int rank; | |||
std::shared_ptr<opr::GroupClient> group_client; | |||
std::string backend; | |||
}; | |||
PackAllReduceReplacePass::GroupInfo::GroupInfo( | |||
int _device, DType _dtype, | |||
size_t _nr_devices, bool _is_root, int _rank, | |||
std::shared_ptr<opr::GroupClient> _group_client, | |||
const std::string& _backend) : | |||
device(_device), dtype(_dtype), | |||
nr_devices(_nr_devices), is_root(_is_root), rank(_rank), | |||
group_client(_group_client), backend(_backend) { | |||
} | |||
uint64_t PackAllReduceReplacePass::GroupInfo::hash(uint64_t extra) const { | |||
DTypeEnum ev = dtype.enumv(); | |||
const std::string& server_addr = group_client->get_addr(); | |||
return XXHash() | |||
.update(&extra, sizeof(uint64_t)) | |||
.update(&device, sizeof(int)) | |||
.update(&ev, sizeof(DTypeEnum)) | |||
.update(&nr_devices, sizeof(size_t)) | |||
.update(&is_root, sizeof(bool)) | |||
.update(&rank, sizeof(int)) | |||
.update(server_addr.c_str(), server_addr.size()) | |||
.update(backend.c_str(), backend.size()) | |||
.digest(); | |||
} | |||
uint64_t PackAllReduceReplacePass::collect_groups(OperatorNodeBase* opr, | |||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||
ThinHashMap<uint64_t, cg::OprNodeArray>& groups) { | |||
// check CollectiveComm oprs that have been marked in PackAllReduceScanPass | |||
if (!opr->same_type<opr::CollectiveComm>()) return 0; | |||
opr::CollectiveComm& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
if (comm.pack_hash() == 0) return 0; // pack_hash not set | |||
VarNode* var = comm.input(0); | |||
auto info = std::make_shared<GroupInfo>( | |||
var->comp_node().locator().device, | |||
var->dtype(), | |||
comm.nr_devices(), | |||
comm.is_root(), | |||
comm.rank(), | |||
comm.group_client(), | |||
comm.backend() | |||
); | |||
uint64_t hash = info->hash(comm.pack_hash()); | |||
if (group_info.find(hash) == group_info.end()) { | |||
group_info.emplace(hash, info); | |||
} | |||
groups[hash].push_back(opr); | |||
return hash; | |||
} | |||
void PackAllReduceReplacePass::divide_packs( | |||
const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||
size_t max_size) { | |||
cg::OprNodeArray pack; | |||
size_t sum = 0; | |||
for (auto it : groups) { | |||
uint64_t hash = it.first; | |||
const cg::OprNodeArray& group = it.second; | |||
for (size_t i = 0; i < group.size(); i++) { | |||
OperatorNodeBase* opr = group[i]; | |||
VarNode* var = opr->input(0); | |||
const TensorShape* shape = var->owner_graph() | |||
->static_infer_manager().infer_shape_fallible(var); | |||
if (shape == nullptr) continue; | |||
pack.push_back(opr); | |||
sum += var->dtype().size(shape->total_nr_elems()); | |||
if (sum >= max_size) { | |||
if (pack.size() > 1) packs[hash].push_back(pack); | |||
pack.clear(); | |||
sum = 0; | |||
} | |||
} | |||
if (pack.size() > 1) packs[hash].push_back(pack); | |||
pack.clear(); | |||
sum = 0; | |||
} | |||
} | |||
void PackAllReduceReplacePass::insert_packed_oprs( | |||
size_t pack_id, | |||
const cg::OprNodeArray& pack, | |||
std::shared_ptr<GroupInfo> info, | |||
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) { | |||
// set priority | |||
mgb_assert(pack.size() > 0); | |||
auto graph = pack[0]->owner_graph(); | |||
auto on_opr_inserted = [priority] (const cg::event::OprInserted& event) { | |||
event.opr->node_prop().attribute().priority = priority; | |||
}; | |||
auto handler = graph->event().register_receiver<cg::event::OprInserted>(on_opr_inserted); | |||
// flatten inputs and record shapes and partition | |||
std::vector<SymbolVar> shapes; | |||
SymbolVarArray flattens; | |||
SymbolVarArray partition; | |||
for (size_t i = 0; i < pack.size(); i++) { | |||
VarNode* var = pack[i]->input(0); | |||
auto shape = opr::GetVarShape::make(SymbolVar(var)); | |||
shapes.push_back(shape); | |||
SymbolVar flatten = SymbolVar(var).flatten(); | |||
flattens.push_back(flatten); | |||
partition.push_back(opr::Reduce::make(shape, {opr::Reduce::Mode::PRODUCT, 0})); | |||
} | |||
// concat | |||
SymbolVar concat = opr::Concat::make(flattens, 0); | |||
// allreduce | |||
std::string key = ssprintf("grad_pack_%zu", pack_id); | |||
auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph, | |||
key, info->nr_devices, info->is_root, info->rank, | |||
info->group_client, param, info->dtype, info->backend)[0]; | |||
// split according to recorded partition | |||
SymbolVarArray splits = opr::Split::make(allreduce, | |||
opr::Split::Options::make_partition(0, partition)); | |||
// reshape and insert results into replace_map | |||
mgb_assert(pack.size() == splits.size()); | |||
for (size_t i = 0; i < pack.size(); i++) { | |||
VarNode* reshape = splits[i].reshape(shapes[i]).node(); | |||
replace_map[pack[i]->output(0)] = reshape; | |||
} | |||
} | |||
void PackAllReduceReplacePass::apply(OptState& opt) const { | |||
// get graph options | |||
auto comp_graph = opt.graph().comp_graph(); | |||
size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024; | |||
size_t ignore_first = comp_graph->options().allreduce_pack_ignore_first; | |||
if (max_size == 0) return; | |||
// get topo order | |||
auto& topo_sorter = static_cast<cg::ComputingGraphImpl*>(comp_graph)->topo_sorter(); | |||
cg::CompSeqExtraInfo extra_info; | |||
VarNodeArray endpoints = to_var_node_array(opt.graph().endpoint_vars()); | |||
const cg::OprNodeArray* seq = topo_sorter.get_comp_seq(extra_info, endpoints); | |||
topo_sorter.restore_opr_prop(); | |||
// collect allreduce groups from topo sequence | |||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
for (size_t i = 0; i < seq->size(); i++) { | |||
if (seq->at(i)->same_type<opr::CollectiveComm>()) { | |||
// ignore the first several allreduces | |||
if (ignore_first > 0) { | |||
--ignore_first; | |||
} else { | |||
collect_groups(seq->at(i), group_info, groups); | |||
} | |||
} | |||
} | |||
// divide groups into packs | |||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs; | |||
divide_packs(groups, packs, max_size); | |||
// make sure that oprs inserted in this pass (reshape, concat, allreduce, | |||
// split, reshape) have higher priority than existing operators | |||
int priority = -seq->size() - 100; | |||
// insert packed operators and generate replace_map | |||
ThinHashMap<VarNode*, VarNode*> replace_map; | |||
size_t pack_id = 0; | |||
for (auto it : packs) { | |||
uint64_t hash = it.first; | |||
for (auto pack : it.second) { | |||
opt.call_with_opr(pack[0], [&]() { | |||
insert_packed_oprs(pack_id, pack, group_info[hash], replace_map, priority); | |||
}, OprPropertyFlag::NONE); | |||
pack_id += 1; | |||
} | |||
} | |||
// replace vars | |||
auto rewriter = opt.graph().make_rewriter(); | |||
auto cb_replace = [&](OperatorNodeBase* opr) { | |||
for (auto i : opr->input()) { | |||
auto iter = replace_map.find(i); | |||
if (iter != replace_map.end()) { | |||
rewriter.replace_var(i, iter->second, nullptr); | |||
} | |||
} | |||
rewriter.auto_replace_outputs(opr); | |||
}; | |||
opt.graph().iter(cb_replace); | |||
rewriter.apply_inplace(); | |||
} | |||
#else | |||
/* ======================= PackAllReduceScanPass ====================== */ | |||
const char* PackAllReduceScanPass::name() const { | |||
return "pack_allreduce_scan"; | |||
} | |||
void PackAllReduceScanPass::apply(OptState& opt) const { | |||
} | |||
bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||
return true; | |||
} | |||
/* ======================= PackAllReduceReplacePass ====================== */ | |||
const char* PackAllReduceReplacePass::name() const { | |||
return "pack_allreduce_replace"; | |||
} | |||
void PackAllReduceReplacePass::apply(OptState& opt) const {} | |||
uint64_t PackAllReduceReplacePass::collect_groups( | |||
OperatorNodeBase* opr, | |||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||
ThinHashMap<uint64_t, cg::OprNodeArray>& groups) { | |||
return 0; | |||
} | |||
void PackAllReduceReplacePass::divide_packs( | |||
const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||
size_t max_size) { | |||
} | |||
void PackAllReduceReplacePass::insert_packed_oprs( | |||
size_t pack_id, | |||
const cg::OprNodeArray& pack, | |||
std::shared_ptr<GroupInfo> info, | |||
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) { | |||
} | |||
#endif // MGB_ENABLE_OPR_MM | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -11,6 +11,8 @@ | |||
#pragma once | |||
#include <vector> | |||
#include "megbrain/gopt/framework.h" | |||
namespace mgb { | |||
@@ -90,6 +92,45 @@ namespace gopt { | |||
void apply(OptState& opt) const override; | |||
}; | |||
//! scan allreduces of param grads | |||
class PackAllReduceScanPass final : public Pass { | |||
public: | |||
const char* name() const override; | |||
void apply(OptState& opt) const override; | |||
private: | |||
// check pattern param -> grad -> allreduce | |||
static bool check_pattern(OperatorNodeBase* opr); | |||
}; | |||
//! pack allreduces of param grads | |||
class PackAllReduceReplacePass final : public Pass { | |||
public: | |||
class GroupInfo; | |||
const char* name() const override; | |||
void apply(OptState& opt) const override; | |||
// collect allreduces and divide into groups | |||
static uint64_t collect_groups( | |||
OperatorNodeBase* opr, | |||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||
ThinHashMap<uint64_t, cg::OprNodeArray>& groups); | |||
// divide groups into packs, max_size in MB | |||
static void divide_packs( | |||
const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||
size_t max_size); | |||
// insert packed operators and update replace_map | |||
static void insert_packed_oprs( | |||
size_t pack_id, | |||
const cg::OprNodeArray& pack, | |||
std::shared_ptr<GroupInfo> info, | |||
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority); | |||
}; | |||
} // namespace gopt | |||
} // namespace mgb | |||
@@ -14,6 +14,7 @@ | |||
#include "megbrain/gopt/basic_arith.h" | |||
#include "megbrain/gopt/misc.h" | |||
#include "megbrain/opr/basic_arith_wrapper.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/cond.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
@@ -410,4 +411,322 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) { | |||
check(x_q8_q8, x_q8_fp32_q8_); | |||
} | |||
#if MGB_ENABLE_OPR_MM | |||
#include "megbrain/opr/collective_comm.h" | |||
#include "../../opr-mm/test/mock_client.h" | |||
TEST_PASS(PackAllReduceScanPass, Basic) { | |||
auto graph = ComputingGraph::make(); | |||
graph->options().allreduce_pack_max_size = 5000; | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto cn = CompNode::load("gpux"); | |||
auto dev_x0 = std::make_shared<DeviceTensorND>(cn, TensorShape{3, 5}); | |||
auto dev_x1 = std::make_shared<DeviceTensorND>(cn, TensorShape{4, 6}); | |||
auto dev_y0 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}); | |||
auto dev_y1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}); | |||
auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0); | |||
auto x1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_x1); | |||
auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0); | |||
auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1); | |||
auto grad0 = opr::VirtualGrad::make(y0, x0); | |||
auto grad1 = opr::VirtualGrad::make(y0, x1); | |||
auto grad2 = opr::VirtualGrad::make(y1, x0); | |||
auto grad3 = opr::VirtualGrad::make(y1, x1); | |||
auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), | |||
"grad0", 2, 0, 0, client, mode)[0]; | |||
auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), | |||
"grad1", 2, 0, 0, client, mode)[0]; | |||
auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), | |||
"grad2", 2, 0, 0, client, mode)[0]; | |||
auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), | |||
"grad3", 2, 0, 0, client, mode)[0]; | |||
gopt::GraphOptimizer() | |||
.add_pass<gopt::PackAllReduceScanPass>() | |||
.apply({{comm0, comm1, comm2, comm3}}); | |||
auto get_hash = [] (const SymbolVar& symvar) { | |||
cg::OperatorNodeBase* opr = symvar.node()->owner_opr(); | |||
return opr->cast_final_safe<opr::CollectiveComm>().pack_hash(); | |||
}; | |||
uint64_t hash0 = get_hash(comm0); | |||
uint64_t hash1 = get_hash(comm1); | |||
uint64_t hash2 = get_hash(comm2); | |||
uint64_t hash3 = get_hash(comm3); | |||
ASSERT_EQ(hash0, hash1); | |||
ASSERT_EQ(hash2, hash3); | |||
ASSERT_NE(hash0, hash2); | |||
} | |||
TEST_PASS(PackAllReduceReplacePass, CollectGroups) { | |||
REQUIRE_GPU(2); | |||
auto cns = load_multiple_xpus(2); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 2; | |||
auto cli0 = std::make_shared<test::MockGroupClient>("mock_addr0"); | |||
auto cli1 = std::make_shared<test::MockGroupClient>("mock_addr1"); | |||
using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
auto add_opr = [&] (const CompNode& cn, TensorShape shape, const DType& dt, | |||
std::shared_ptr<test::MockGroupClient> client, uint64_t extra_hash) { | |||
auto dev0 = std::make_shared<DeviceTensorND>(cn, shape, dt); | |||
auto wrt = opr::SharedDeviceTensor::make(*graph, dev0); | |||
auto dev1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}, dt); | |||
auto target = opr::SharedDeviceTensor::make(*graph, dev1); | |||
auto grad = opr::VirtualGrad::make(target, wrt); | |||
auto comm = opr::CollectiveComm::make( | |||
{grad}, graph.get(), "key", 2, 0, 0, client, | |||
opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] | |||
.node()->owner_opr(); | |||
comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash); | |||
return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info, groups); | |||
}; | |||
uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1); | |||
uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1); // same | |||
uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1); // comp_node | |||
uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1); // dtype | |||
uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1); // client | |||
uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2); // extra_hash | |||
ASSERT_EQ(hash0, hash1); | |||
std::set<uint64_t> s; | |||
s.insert(hash0); | |||
s.insert(hash1); | |||
s.insert(hash2); | |||
s.insert(hash3); | |||
s.insert(hash4); | |||
s.insert(hash5); | |||
ASSERT_EQ(5, s.size()); | |||
ASSERT_EQ(1, group_info.count(hash0)); | |||
ASSERT_EQ(1, group_info.count(hash1)); | |||
ASSERT_EQ(1, group_info.count(hash2)); | |||
ASSERT_EQ(1, group_info.count(hash3)); | |||
ASSERT_EQ(1, group_info.count(hash4)); | |||
ASSERT_EQ(1, group_info.count(hash5)); | |||
ASSERT_EQ(2, groups[hash0].size()); | |||
ASSERT_EQ(2, groups[hash1].size()); | |||
ASSERT_EQ(1, groups[hash2].size()); | |||
ASSERT_EQ(1, groups[hash3].size()); | |||
ASSERT_EQ(1, groups[hash4].size()); | |||
ASSERT_EQ(1, groups[hash5].size()); | |||
} | |||
TEST_PASS(PackAllReduceReplacePass, DividePacks) { | |||
auto cn = CompNode::load("gpux"); | |||
auto graph = ComputingGraph::make(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs; | |||
auto insert_opr = [&] (size_t size) { | |||
auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)}); | |||
auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||
auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||
"key", 2, 0, 0, client, mode)[0]; | |||
auto opr = symvar.node()->owner_opr(); | |||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
comm.set_pack_hash(1); | |||
return opr; | |||
}; | |||
auto pack_size = [&] (cg::OprNodeArray& pack) { | |||
size_t sum = 0; | |||
for (size_t i = 0; i < pack.size(); i++) { | |||
auto var = pack[i]->input(0); | |||
sum += var->dtype().size(var->shape().total_nr_elems()); | |||
} | |||
return sum; | |||
}; | |||
groups[0].push_back(insert_opr(100)); // group0, pack0, size=1100 | |||
groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 | |||
groups[0].push_back(insert_opr(400)); // group0, pack0, size=1100 | |||
groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 | |||
groups[0].push_back(insert_opr(500)); // group0, pack1, size=800 | |||
groups[0].push_back(insert_opr(200)); // group0, pack1, size=800 | |||
groups[0].push_back(insert_opr(100)); // group0, pack1, size=800 | |||
groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 | |||
groups[1].push_back(insert_opr(400)); // group1, pack0, size=900 | |||
groups[1].push_back(insert_opr(300)); // group1, pack0, size=900 | |||
groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 | |||
gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000); | |||
ASSERT_EQ(2, packs.size()); | |||
ASSERT_EQ(2, packs[0].size()); | |||
ASSERT_EQ(4, packs[0][0].size()); | |||
ASSERT_EQ(1100, pack_size(packs[0][0])); | |||
ASSERT_EQ(3, packs[0][1].size()); | |||
ASSERT_EQ(800, pack_size(packs[0][1])); | |||
ASSERT_EQ(1, packs[1].size()); | |||
ASSERT_EQ(4, packs[1][0].size()); | |||
ASSERT_EQ(900, pack_size(packs[1][0])); | |||
} | |||
TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { | |||
auto cn = CompNode::load("gpux"); | |||
auto graph = ComputingGraph::make(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
size_t nr_devices = 2; | |||
uint32_t rank = 0; | |||
uint32_t root = 0; | |||
using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||
ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
auto insert_opr = [&] (const TensorShape& shape) { | |||
auto dev = std::make_shared<DeviceTensorND>(cn, shape); | |||
auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||
auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||
"key", nr_devices, rank, root, client, mode)[0]; | |||
auto opr = symvar.node()->owner_opr(); | |||
auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
comm.set_pack_hash(1); | |||
gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups); | |||
return symvar; | |||
}; | |||
auto shape_x = TensorShape{100, 200}; | |||
auto shape_y = TensorShape{200, 400}; | |||
auto x = insert_opr(shape_x); | |||
auto y = insert_opr(shape_y); | |||
ASSERT_EQ(1, group_info.size()); | |||
ASSERT_EQ(1, groups.size()); | |||
auto info = group_info.begin()->second; | |||
auto pack = groups.begin()->second; | |||
size_t pack_id = 0; | |||
ThinHashMap<VarNode*, VarNode*> replace_map; | |||
gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info, replace_map, -1); | |||
auto grad_x = SymbolVar(x.node()->owner_opr()->input(0)); | |||
auto grad_y = SymbolVar(y.node()->owner_opr()->input(0)); | |||
auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0); | |||
std::string key = ssprintf("grad_pack_%zu", pack_id); | |||
auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), | |||
key, nr_devices, rank, root, client, mode)[0]; | |||
std::vector<size_t> partition; | |||
partition.push_back(shape_x.total_nr_elems()); | |||
partition.push_back(shape_y.total_nr_elems()); | |||
auto splits = opr::Split::make(allreduce, | |||
opr::Split::Options::make_partition(allreduce, 0, partition)); | |||
ASSERT_EQ(2, splits.size()); | |||
auto dest_x = splits[0].reshape(shape_x); | |||
auto dest_y = splits[1].reshape(shape_y); | |||
ASSERT_EQ(2, replace_map.size()); | |||
ASSERT_TRUE(replace_map.count(x.node()) > 0); | |||
ASSERT_EQ(replace_map.at(x.node()), dest_x.node()); | |||
ASSERT_TRUE(replace_map.count(y.node()) > 0); | |||
ASSERT_EQ(replace_map.at(y.node()), dest_y.node()); | |||
} | |||
TEST_PASS(PackAllReduceReplacePass, Equivalence) { | |||
REQUIRE_GPU(2); | |||
auto cns = load_multiple_xpus(2); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto build_graph = [&] (uint32_t rank, std::shared_ptr<ComputingGraph> graph, | |||
SymbolVarArray& array) { | |||
HostTensorGenerator<> gen; | |||
auto cn = cns[rank]; | |||
auto host_x = gen({1, 1000}); | |||
auto host_y = gen({1000, 1}); | |||
auto dev_x = std::make_shared<DeviceTensorND>(cn); | |||
auto dev_y = std::make_shared<DeviceTensorND>(cn); | |||
dev_x->copy_from(*host_x).sync(); | |||
dev_y->copy_from(*host_y).sync(); | |||
auto x = opr::SharedDeviceTensor::make(*graph, dev_x); | |||
auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y); | |||
auto loss = opr::MatrixMul::make(x, y).flatten(); | |||
auto grad_x = opr::VirtualGrad::make(loss, x); | |||
auto grad_y = opr::VirtualGrad::make(loss, y); | |||
using Mode = opr::CollectiveComm::Param::Mode; | |||
bool is_root = (rank == 0); | |||
auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), | |||
"x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||
auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), | |||
"y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||
graph->options().allreduce_pack_max_size = 5000; | |||
graph->options().allreduce_pack_ignore_first = 0; | |||
auto dest_vars = gopt::GraphOptimizer{} | |||
.add_pass<gopt::PackAllReduceScanPass>() | |||
.add_pass<gopt::PackAllReduceReplacePass>() | |||
.apply({{reduced_x, reduced_y}}).endpoint_vars(); | |||
array.emplace_back(reduced_x); | |||
array.emplace_back(reduced_y); | |||
array.emplace_back(dest_vars[0]); | |||
array.emplace_back(dest_vars[1]); | |||
}; | |||
auto run = [&] (uint32_t rank) { | |||
auto graph = ComputingGraph::make(); | |||
SymbolVarArray array; | |||
build_graph(rank, graph, array); | |||
HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1; | |||
graph->options().allreduce_pack_max_size = 0; | |||
auto func = graph->compile({make_callback_copy(array[0], host_reduced_x), | |||
make_callback_copy(array[1], host_reduced_y), | |||
make_callback_copy(array[2], host_dest_0), | |||
make_callback_copy(array[3], host_dest_1)}); | |||
func->execute(); | |||
MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0); | |||
MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1); | |||
}; | |||
std::thread t0(run, 0); | |||
std::thread t1(run, 1); | |||
t0.join(); | |||
t1.join(); | |||
} | |||
#endif // MGB_ENABLE_OPR_MM | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -461,16 +461,7 @@ void CollectiveComm::opr_register() { | |||
m_rank = reg_info.rank; | |||
m_root = reg_info.root_rank; | |||
MegRayCommunicatorBuilder* builder; | |||
{ | |||
static std::mutex user_data_mtx; | |||
std::unique_lock<std::mutex> lk(user_data_mtx); | |||
builder = owner_graph()->options().user_data | |||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
} | |||
m_megray_comm = builder->get_megray_comm( | |||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
reg_info.hash, m_key, m_nr_devices, m_rank, | |||
get_megray_backend(m_backend), m_group_client); | |||
@@ -736,13 +727,15 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( | |||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||
const OperatorNodeConfig& config) { | |||
auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | |||
return opr::CollectiveComm::make( | |||
auto new_opr = CollectiveComm::make( | |||
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | |||
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | |||
opr.group_client(), opr.dev_buffers(), opr.param(), | |||
opr.dtype(), opr.backend(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash()); | |||
return new_opr; | |||
} | |||
MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm); | |||
@@ -54,13 +54,7 @@ void RemoteSend::scn_do_execute() { | |||
auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | |||
comp_node.get_uid()); | |||
auto megray_comm_builder = | |||
owner_graph() | |||
->options() | |||
.user_data | |||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
m_megray_comm = megray_comm_builder->get_megray_comm( | |||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
m_init = true; | |||
} | |||
@@ -158,13 +152,7 @@ void RemoteRecv::scn_do_execute() { | |||
m_peer.key, 2, false, 1, | |||
comp_node.get_uid()); | |||
auto megray_comm_builder = | |||
owner_graph() | |||
->options() | |||
.user_data | |||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
m_megray_comm = megray_comm_builder->get_megray_comm( | |||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
m_init = true; | |||
} | |||
@@ -14,8 +14,8 @@ | |||
using namespace mgb; | |||
using namespace opr; | |||
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||
std::unique_lock<std::mutex> lk(m_mtx); | |||
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||
std::unique_lock<std::mutex> lk(m_map_mtx); | |||
auto it = m_megray_comms.find(hash); | |||
if (it != m_megray_comms.end()) { | |||
comm = it->second; | |||
@@ -24,27 +24,37 @@ bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Comm | |||
return false; | |||
} | |||
void MegRayCommunicatorBuilder::emplace(uint64_t hash, | |||
void MegRayCommBuilder::emplace(uint64_t hash, | |||
std::shared_ptr<MegRay::Communicator> comm) { | |||
std::unique_lock<std::mutex> lk(m_mtx); | |||
std::unique_lock<std::mutex> lk(m_map_mtx); | |||
m_megray_comms.emplace(hash, comm); | |||
} | |||
std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm( | |||
std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||
uint64_t hash, std::string key, uint32_t size, uint32_t rank, | |||
MegRay::Backend backend, | |||
std::shared_ptr<mgb::opr::GroupClient> group_client) { | |||
{ | |||
// singleton pattern | |||
std::unique_lock<std::mutex> lk(sm_instance_mtx); | |||
if (sm_instance == nullptr) { | |||
sm_instance = new MegRayCommBuilder(); | |||
} | |||
} | |||
std::shared_ptr<MegRay::Communicator> comm; | |||
if (!find(hash, comm)) { | |||
if (!sm_instance->find(hash, comm)) { | |||
comm = MegRay::get_communicator(size, rank, backend); | |||
auto uid = comm->get_uid(); | |||
auto uids = group_client->gather_uid(uid, key, size, rank); | |||
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | |||
emplace(hash, comm); | |||
sm_instance->emplace(hash, comm); | |||
} | |||
return comm; | |||
} | |||
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); | |||
MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | |||
std::mutex MegRayCommBuilder::sm_instance_mtx; | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -81,6 +81,10 @@ public: | |||
return m_group_client; | |||
} | |||
void set_pack_hash(uint64_t hash) { m_pack_hash = hash; } | |||
uint64_t pack_hash() const { return m_pack_hash; } | |||
std::shared_ptr<MegRay::Context> megray_ctx() const { | |||
return m_megray_ctx; | |||
} | |||
@@ -123,6 +127,9 @@ private: | |||
// whose shape infer should be disabled *during* static infer phase. | |||
bool m_enable_shape_infer = false; | |||
//! set in PackAllReduceScanPass and used in PackAllReduceReplacePass | |||
uint64_t m_pack_hash = 0; | |||
std::shared_ptr<MegRay::Context> m_megray_ctx; | |||
std::shared_ptr<MegRay::Communicator> m_megray_comm; | |||
bool m_init = false; | |||
@@ -126,6 +126,8 @@ class GroupClient { | |||
virtual ~GroupClient() = default; | |||
public: | |||
virtual const std::string& get_addr() const = 0; | |||
virtual GroupManager::RegisterInfo opr_register(const std::string& key, | |||
size_t nr_devices, | |||
bool is_root, int rank, | |||
@@ -23,18 +23,19 @@ namespace opr { | |||
/*! | |||
* gather MegRay unique ids and build communicator, use hash for deduplication | |||
*/ | |||
class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData { | |||
MGB_TYPEINFO_OBJ_DECL; | |||
class MegRayCommBuilder { | |||
private: | |||
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | |||
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | |||
std::mutex m_mtx; | |||
std::mutex m_map_mtx; | |||
static MegRayCommBuilder* sm_instance; | |||
static std::mutex sm_instance_mtx; | |||
public: | |||
std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||
static std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||
uint64_t hash, std::string key, uint32_t size, uint32_t rank, | |||
MegRay::Backend backend, | |||
std::shared_ptr<mgb::opr::GroupClient> group_client); | |||
@@ -47,7 +47,7 @@ public: | |||
uint32_t group_barrier(uint32_t size, uint32_t rank) override; | |||
const std::string& get_addr() const { | |||
const std::string& get_addr() const override { | |||
return m_addr; | |||
} | |||
@@ -17,11 +17,10 @@ | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/test/helper.h" | |||
#include "megbrain/graph.h" | |||
#include "mock_client.h" | |||
using namespace mgb; | |||
namespace { | |||
using Mode = opr::CollectiveComm::Param::Mode; | |||
SymbolVar make_all_reduce_output(const Mode mode, | |||
@@ -41,41 +40,6 @@ SymbolVarArray make_reduce_scatter_sum_output(const SymbolVarArray& inputs) { | |||
rdc, opr::Split::Options::make_average(0, inputs.size())); | |||
} | |||
class MockGroupClient final : public opr::GroupClient { | |||
public: | |||
~MockGroupClient() override = default; | |||
opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||
size_t nr_devices, | |||
bool is_root, int rank, | |||
uintptr_t stream) { | |||
return m_mgr.opr_register(key, nr_devices, is_root, rank, stream); | |||
} | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) { | |||
return m_mgr.gather_uid(uid, key, size, rank); | |||
} | |||
void set_output_shape(const std::string& key, | |||
const TensorShape& shape) override { | |||
m_mgr.set_output_shape(key, shape); | |||
} | |||
TensorShape get_output_shape(const std::string& key) override { | |||
return m_mgr.get_output_shape(key); | |||
} | |||
uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||
return m_mgr.group_barrier(size, rank); | |||
} | |||
private: | |||
opr::GroupManager m_mgr; | |||
}; | |||
} // namespace | |||
TEST(TestOprCollectiveComm, AllReduce) { | |||
REQUIRE_GPU(2); | |||
@@ -88,7 +52,7 @@ TEST(TestOprCollectiveComm, AllReduce) { | |||
auto host_x1 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1, host_y_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto graph = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
@@ -126,7 +90,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { | |||
auto host_x1 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1, host_y_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -187,7 +151,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { | |||
HostTensorND host_y0, host_y1, host_y_expect; | |||
HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -268,7 +232,7 @@ TEST(TestOprCollectiveComm, AllGather) { | |||
auto host_x1 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1, host_y_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto graph = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
@@ -300,7 +264,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { | |||
auto host_x1 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1, host_y_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -356,7 +320,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { | |||
HostTensorND host_out_grad0, host_out_grad1; | |||
HostTensorND host_out_grad0_expect, host_out_grad1_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -438,7 +402,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { | |||
auto host_x1 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto graph = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
@@ -471,7 +435,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||
auto host_x1 = gen({8}); | |||
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -528,7 +492,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | |||
HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -610,7 +574,7 @@ TEST(TestOprCollectiveComm, ReduceSum) { | |||
auto host_x1 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1, host_y_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto graph = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
@@ -641,7 +605,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||
auto host_x1 = gen({28, 28}); | |||
HostTensorND host_y0, host_y_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -694,7 +658,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||
HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -764,7 +728,7 @@ TEST(TestOprCollectiveComm, Broadcast) { | |||
auto host_x0 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1, host_y_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto graph = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
@@ -794,7 +758,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||
auto host_x0 = gen({28, 28}); | |||
HostTensorND host_y0, host_y1; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -840,7 +804,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||
HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
@@ -14,51 +14,14 @@ | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/system.h" | |||
#include "megbrain/test/helper.h" | |||
#include "mock_client.h" | |||
#include <thread> | |||
using namespace mgb; | |||
using namespace opr; | |||
namespace { | |||
class MockGroupClient final : public opr::GroupClient { | |||
public: | |||
~MockGroupClient() override = default; | |||
opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||
size_t nr_devices, | |||
bool is_root, int rank, | |||
uint64_t comp_node_hash) { | |||
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||
} | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) { | |||
return m_mgr.gather_uid(uid, key, size, rank); | |||
} | |||
void set_output_shape(const std::string& key, | |||
const TensorShape& shape) override { | |||
m_mgr.set_output_shape(key, shape); | |||
} | |||
TensorShape get_output_shape(const std::string& key) override { | |||
return m_mgr.get_output_shape(key); | |||
} | |||
uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||
return m_mgr.group_barrier(size, rank); | |||
} | |||
private: | |||
opr::GroupManager m_mgr; | |||
}; | |||
const auto send_tag = RemoteIOBase::Type::SEND; | |||
const auto recv_tag = RemoteIOBase::Type::RECV; | |||
} // anonymous namespace | |||
const auto send_tag = opr::RemoteIOBase::Type::SEND; | |||
const auto recv_tag = opr::RemoteIOBase::Type::RECV; | |||
TEST(TestOprIORemote, Identity) { | |||
REQUIRE_GPU(2); | |||
@@ -69,7 +32,7 @@ TEST(TestOprIORemote, Identity) { | |||
auto host_x = gen({28, 28}); | |||
HostTensorND host_y; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | |||
@@ -90,7 +53,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||
HostTensorGenerator<> gen; | |||
auto host_x = gen({2, 3}, cns[1]); | |||
HostTensorND host_x_get; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto sender = [&]() { | |||
auto graph = ComputingGraph::make(); | |||
@@ -123,7 +86,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||
HostTensorGenerator<> gen; | |||
auto host_x = gen({2, 3}, cns[0]); | |||
HostTensorND host_x_get; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto sender = [&]() { | |||
sys::set_thread_name("sender"); | |||
@@ -157,7 +120,7 @@ TEST(TestOprIORemote, APlusB) { | |||
HostTensorGenerator<> gen; | |||
auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, cns[0]); | |||
HostTensorND host_z; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto sender = [&]() { | |||
auto graph = ComputingGraph::make(); | |||
@@ -208,7 +171,7 @@ TEST(TestOprIORemote, SendGrad) { | |||
HostTensorGenerator<> gen; | |||
auto host_x = gen({2, 3}, cns[0]); | |||
HostTensorND host_gx, host_loss; | |||
auto client = std::make_shared<MockGroupClient>(); | |||
auto client = std::make_shared<test::MockGroupClient>(); | |||
auto sender = [&]() { | |||
sys::set_thread_name("sender"); | |||
@@ -0,0 +1,62 @@ | |||
/** | |||
* \file src/opr-mm/test/mock_client.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/opr/group_manager.h" | |||
namespace mgb { | |||
namespace test { | |||
class MockGroupClient final : public opr::GroupClient { | |||
public: | |||
using RegisterInfo = opr::GroupManager::RegisterInfo; | |||
MockGroupClient(const std::string& server_addr = "mock_addr") : | |||
m_addr(server_addr) { | |||
} | |||
~MockGroupClient() override = default; | |||
const std::string& get_addr() const { | |||
return m_addr; | |||
} | |||
RegisterInfo opr_register(const std::string& key, size_t nr_devices, | |||
bool is_root, int rank, uint64_t comp_node_hash) { | |||
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||
} | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) { | |||
return m_mgr.gather_uid(uid, key, size, rank); | |||
} | |||
void set_output_shape(const std::string& key, | |||
const TensorShape& shape) override { | |||
m_mgr.set_output_shape(key, shape); | |||
} | |||
TensorShape get_output_shape(const std::string& key) override { | |||
return m_mgr.get_output_shape(key); | |||
} | |||
uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||
return m_mgr.group_barrier(size, rank); | |||
} | |||
private: | |||
const std::string m_addr; | |||
opr::GroupManager m_mgr; | |||
}; | |||
} // namespace test | |||
} // namespace mgb | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |