Browse Source

fix(mgb/opr-mm): use comp_node of config as default in CollectiveComm

GitOrigin-RevId: 6b43c9fc93
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
1bce857cb8
1 changed files with 18 additions and 92 deletions
  1. +18
    -92
      src/opr-mm/impl/collective_comm.cpp

+ 18
- 92
src/opr-mm/impl/collective_comm.cpp View File

@@ -107,27 +107,9 @@ protected:
}
}

static void add_output_var_all2all(CollectiveComm* opr) {
mgb_assert(opr->nr_devices() >= 2);
auto pname = get_param_name(opr->param());
// sublinear would setup opr->config if inputs.size() is 1,
// bypass this situation
mgb_assert(
!opr->config().has_comp_node_set() || opr->input().size() == 1,
"comp node should not be set in %s mode", pname);
for (auto i : opr->input()) {
opr->add_output(ssprintf("%s:%s", pname, i->cname()))
->comp_node(i->comp_node());
}
}

public:
virtual ~ModeTrait() = default;

//! add output var for the opr
virtual void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet& inp_cn) = 0;

/*!
* \brief the vars on whose comp node the computing should be performed
* if None, output vars would be used
@@ -188,11 +170,6 @@ public:
};

class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}

void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -231,11 +208,6 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
};

class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}

void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -292,11 +264,6 @@ protected:

class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}

void get_output_var_shape(const CollectiveComm*,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -368,11 +335,6 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {

class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet& inp_cn) override {
add_output_var_all2all(opr);
}

void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -413,19 +375,6 @@ class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
};

class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
if (opr->input().size() > 0) {
add_output_var_all2all(opr);
return;
}

const auto& cns = opr->config().comp_node();
mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
auto pname = get_param_name(opr->param());
opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
}

void get_output_var_shape(const CollectiveComm*,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -462,11 +411,6 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
};

class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}

void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -501,19 +445,6 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
};

class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
if (opr->input().size() > 0) {
add_output_var_all2all(opr);
return;
}

const auto& cns = opr->config().comp_node();
mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
auto pname = get_param_name(opr->param());
opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
}

void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -537,11 +468,6 @@ class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
};

class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}

void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -617,35 +543,35 @@ CollectiveComm::CollectiveComm(
m_key(key),
m_dev_buffers(dev_buffer_arr),
m_disable{disable} {
for (auto i : inputs) {
mgb_assert(i->comp_node().device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
}
for (auto i : config.comp_node()) {
mgb_assert(i.device_type() == CompNode::DeviceType::CUDA,
// add input
mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
if (inputs.size() > 0) {
mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
add_input({inputs[0]});
}

CompNode::UnorderedSet inp_cn;
ThinHashSet<int> inp_dev;
// add output
add_output(ssprintf("%s:%s", get_param_name(param), key.c_str()));

for (auto i : inputs) {
add_input({i});
inp_cn.insert(i->comp_node());
inp_dev.insert(
CompNodeEnv::from_comp_node(i->comp_node()).cuda_env().device);
// set comp node
const auto& cns = config.comp_node();
mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
if (cns.size() > 0) {
mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
output(0)->comp_node(cns[0]);
} else {
output(0)->comp_node(inputs[0]->comp_node());
}
mgb_assert(
inp_dev.size() == inputs.size(),
"CollectiveComm inputs should not contain duplicated input device");

ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn);

// set debug flag
const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
m_debug_mode = true;
}

// deduplication
add_equivalence_component<PODHash<Param>>(&m_param);
add_equivalence_component<PODHash<size_t>>(&m_nr_devices);
m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest();


Loading…
Cancel
Save