Browse Source

fix(mgb/opr-mm): fix grad func of reduce and gather

GitOrigin-RevId: 4687faef99
release-0.6
Megvii Engine Team 4 years ago
parent
commit
5e912eddbd
1 changed files with 14 additions and 1 deletions
  1. +14
    -1
      src/opr-mm/impl/collective_comm.cpp

+ 14
- 1
src/opr-mm/impl/collective_comm.cpp View File

@@ -139,7 +139,10 @@ public:
VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const { VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode(); auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
SymbolVarArray og_syms; SymbolVarArray og_syms;
og_syms.push_back(out_grad);

if (out_grad != nullptr) {
og_syms.push_back(out_grad);
}


auto&& cn = opr->output(0)->comp_node(); auto&& cn = opr->output(0)->comp_node();


@@ -401,6 +404,11 @@ class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase { class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; } MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; }


VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNode* input = opr->is_root() ? out_grad : nullptr;
return full_grad(input, opr);
}

Mode grad_mode() override { return Mode::BROADCAST; } Mode grad_mode() override { return Mode::BROADCAST; }
}; };


@@ -484,6 +492,11 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed"); mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed");
} }


VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNode* input = opr->is_root() ? out_grad : nullptr;
return full_grad(input, opr);
}

Mode grad_mode() override { return Mode::SCATTER; } Mode grad_mode() override { return Mode::SCATTER; }
}; };




Loading…
Cancel
Save