|
@@ -139,7 +139,10 @@ public: |
|
|
VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const { |
|
|
VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const { |
|
|
auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode(); |
|
|
auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode(); |
|
|
SymbolVarArray og_syms; |
|
|
SymbolVarArray og_syms; |
|
|
og_syms.push_back(out_grad); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (out_grad != nullptr) { |
|
|
|
|
|
og_syms.push_back(out_grad); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
auto&& cn = opr->output(0)->comp_node(); |
|
|
auto&& cn = opr->output(0)->comp_node(); |
|
|
|
|
|
|
|
@@ -401,6 +404,11 @@ class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait, |
|
|
class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase { |
|
|
class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase { |
|
|
MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; } |
|
|
MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; } |
|
|
|
|
|
|
|
|
|
|
|
VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { |
|
|
|
|
|
VarNode* input = opr->is_root() ? out_grad : nullptr; |
|
|
|
|
|
return full_grad(input, opr); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Mode grad_mode() override { return Mode::BROADCAST; } |
|
|
Mode grad_mode() override { return Mode::BROADCAST; } |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
@@ -484,6 +492,11 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait { |
|
|
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed"); |
|
|
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed"); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { |
|
|
|
|
|
VarNode* input = opr->is_root() ? out_grad : nullptr; |
|
|
|
|
|
return full_grad(input, opr); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Mode grad_mode() override { return Mode::SCATTER; } |
|
|
Mode grad_mode() override { return Mode::SCATTER; } |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|