|
|
@@ -107,27 +107,9 @@ protected: |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
//! Shared helper for the "all-to-all style" modes (ALL_GATHER,
//! REDUCE_SCATTER_SUM, ALL_REDUCE_*, REDUCE, GATHER, ALL_TO_ALL):
//! adds one output var per input var, named "<param-name>:<input-name>",
//! each placed on the same comp node as its paired input.
static void add_output_var_all2all(CollectiveComm* opr) {
    // a multi-device collective needs at least two participants
    mgb_assert(opr->nr_devices() >= 2);
    auto pname = get_param_name(opr->param());
    // sublinear would setup opr->config if inputs.size() is 1,
    // bypass this situation
    mgb_assert(
            !opr->config().has_comp_node_set() || opr->input().size() == 1,
            "comp node should not be set in %s mode", pname);
    for (auto i : opr->input()) {
        // output i lives on the same comp node as input i
        opr->add_output(ssprintf("%s:%s", pname, i->cname()))
                ->comp_node(i->comp_node());
    }
}
|
|
|
|
|
|
|
public: |
|
|
|
virtual ~ModeTrait() = default; |
|
|
|
|
|
|
|
//! add output var for the opr |
|
|
|
virtual void add_output_var(CollectiveComm* opr, |
|
|
|
const CompNode::UnorderedSet& inp_cn) = 0; |
|
|
|
|
|
|
|
/*! |
|
|
|
* \brief the vars on whose comp node the computing should be performed |
|
|
|
* if None, output vars would be used |
|
|
@@ -188,11 +170,6 @@ public: |
|
|
|
}; |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait { |
|
|
|
//! Outputs mirror the inputs one-to-one; delegate to the shared
//! all2all helper (the comp-node set argument is not needed here).
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    add_output_var_all2all(opr);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm* opr, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -231,11 +208,6 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait { |
|
|
|
}; |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait { |
|
|
|
//! Outputs mirror the inputs one-to-one; delegate to the shared
//! all2all helper (the comp-node set argument is not needed here).
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    add_output_var_all2all(opr);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm* opr, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -292,11 +264,6 @@ protected: |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait, |
|
|
|
public ModeTrait { |
|
|
|
//! Outputs mirror the inputs one-to-one; delegate to the shared
//! all2all helper (the comp-node set argument is not needed here).
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    add_output_var_all2all(opr);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm*, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -368,11 +335,6 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase { |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait, |
|
|
|
public ModeTrait { |
|
|
|
//! Outputs mirror the inputs one-to-one; delegate to the shared
//! all2all helper. The comp-node set is unused, so the parameter is
//! left unnamed — consistent with the sibling ModeTrait overrides and
//! avoids an unused-parameter warning.
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    add_output_var_all2all(opr);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm* opr, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -413,19 +375,6 @@ class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase { |
|
|
|
}; |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait { |
|
|
|
//! If this node has an input var, outputs are created per-input via the
//! shared all2all helper. Otherwise (no input — presumably the receiving
//! side of the broadcast; confirm against callers) the placement must
//! come from the opr config: exactly one comp node is required and a
//! single output named after the collective key is placed on it.
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    if (opr->input().size() > 0) {
        add_output_var_all2all(opr);
        return;
    }
    // no input: comp node must be supplied explicitly via the config
    const auto& cns = opr->config().comp_node();
    mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
    auto pname = get_param_name(opr->param());
    opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm*, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -462,11 +411,6 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait { |
|
|
|
}; |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::GATHER : public ModeTrait { |
|
|
|
//! Outputs mirror the inputs one-to-one; delegate to the shared
//! all2all helper (the comp-node set argument is not needed here).
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    add_output_var_all2all(opr);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm* opr, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -501,19 +445,6 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait { |
|
|
|
}; |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::SCATTER : public ModeTrait { |
|
|
|
//! If this node has an input var, outputs are created per-input via the
//! shared all2all helper. Otherwise (no input — presumably a receiving
//! side of the scatter; confirm against callers) the placement must come
//! from the opr config: exactly one comp node is required and a single
//! output named after the collective key is placed on it.
//! NOTE(review): this body is identical to BROADCAST::add_output_var;
//! consider hoisting it into a shared protected helper.
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    if (opr->input().size() > 0) {
        add_output_var_all2all(opr);
        return;
    }
    // no input: comp node must be supplied explicitly via the config
    const auto& cns = opr->config().comp_node();
    mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
    auto pname = get_param_name(opr->param());
    opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm* opr, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -537,11 +468,6 @@ class CollectiveComm::ModeTrait::SCATTER : public ModeTrait { |
|
|
|
}; |
|
|
|
|
|
|
|
class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait { |
|
|
|
//! Outputs mirror the inputs one-to-one; delegate to the shared
//! all2all helper (the comp-node set argument is not needed here).
void add_output_var(CollectiveComm* opr,
                    const CompNode::UnorderedSet&) override {
    add_output_var_all2all(opr);
}
|
|
|
|
|
|
|
void get_output_var_shape(const CollectiveComm* opr, |
|
|
|
const TensorShapeArray& ishp, |
|
|
|
TensorShapeArray& oshp) override { |
|
|
@@ -617,35 +543,35 @@ CollectiveComm::CollectiveComm( |
|
|
|
m_key(key), |
|
|
|
m_dev_buffers(dev_buffer_arr), |
|
|
|
m_disable{disable} { |
|
|
|
for (auto i : inputs) { |
|
|
|
mgb_assert(i->comp_node().device_type() == CompNode::DeviceType::CUDA, |
|
|
|
"CollectiveComm currectly only supports CUDA"); |
|
|
|
} |
|
|
|
for (auto i : config.comp_node()) { |
|
|
|
mgb_assert(i.device_type() == CompNode::DeviceType::CUDA, |
|
|
|
// add input |
|
|
|
mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size()); |
|
|
|
if (inputs.size() > 0) { |
|
|
|
mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA, |
|
|
|
"CollectiveComm currectly only supports CUDA"); |
|
|
|
add_input({inputs[0]}); |
|
|
|
} |
|
|
|
|
|
|
|
CompNode::UnorderedSet inp_cn; |
|
|
|
ThinHashSet<int> inp_dev; |
|
|
|
// add output |
|
|
|
add_output(ssprintf("%s:%s", get_param_name(param), key.c_str())); |
|
|
|
|
|
|
|
for (auto i : inputs) { |
|
|
|
add_input({i}); |
|
|
|
inp_cn.insert(i->comp_node()); |
|
|
|
inp_dev.insert( |
|
|
|
CompNodeEnv::from_comp_node(i->comp_node()).cuda_env().device); |
|
|
|
// set comp node |
|
|
|
const auto& cns = config.comp_node(); |
|
|
|
mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size()); |
|
|
|
if (cns.size() > 0) { |
|
|
|
mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA, |
|
|
|
"CollectiveComm currectly only supports CUDA"); |
|
|
|
output(0)->comp_node(cns[0]); |
|
|
|
} else { |
|
|
|
output(0)->comp_node(inputs[0]->comp_node()); |
|
|
|
} |
|
|
|
mgb_assert( |
|
|
|
inp_dev.size() == inputs.size(), |
|
|
|
"CollectiveComm inputs should not contain duplicated input device"); |
|
|
|
|
|
|
|
ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); |
|
|
|
|
|
|
|
// set debug flag |
|
|
|
const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); |
|
|
|
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { |
|
|
|
m_debug_mode = true; |
|
|
|
} |
|
|
|
|
|
|
|
// deduplication |
|
|
|
add_equivalence_component<PODHash<Param>>(&m_param); |
|
|
|
add_equivalence_component<PODHash<size_t>>(&m_nr_devices); |
|
|
|
m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest(); |
|
|
|