@@ -20,6 +20,9 @@ class AttachSpec:
     __slots__ = "tensor", "callbacks"
 
 
+_global_priority = 0
+
+
 class GradManager:
     r"""
     GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
@@ -118,6 +121,7 @@ class GradManager:
         self._grad = None
         self._after_backward_callback = []
         self._gradients = {}
+        self._priority = None
 
     def attach(self, tensors: Iterable[Tensor], callbacks=None):
         r"""
@@ -293,6 +297,7 @@ class GradManager:
 
         After this call, you will be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._recording:
             raise RuntimeError("already recording")
         grad = Grad()
@@ -300,6 +305,9 @@ class GradManager:
         self._grad = grad
         for spec in self._attach_specs.values():
             self._do_record(spec)
+        if self._priority is None:
+            grad._priority = _global_priority
+            _global_priority -= 1
         grad.__enter__()
 
     def _do_record(self, spec):
@@ -321,11 +329,14 @@ class GradManager:
 
         After this call, you will not be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
             self._grad = None
         self._recording = False
         self._gradients = dict()
+        if self._priority is None:
+            _global_priority += 1
 
     def __enter__(self):
         self.record()
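The counter protocol above is what makes nesting work: a manager that records without an explicit priority claims the current `_global_priority` and decrements it, so recordings opened later (nested deeper) get strictly lower priorities, and `release` undoes the decrement. A self-contained sketch of just that protocol (`_Recorder` is a stand-in, not a MegEngine class):

```python
# Stand-in sketch of the _global_priority protocol above; not MegEngine API.
_global_priority = 0

class _Recorder:
    def __init__(self):
        self._priority = None          # None => claim the global counter

    def record(self):
        global _global_priority
        if self._priority is None:
            self.claimed = _global_priority
            _global_priority -= 1      # next recording gets a lower priority

    def release(self):
        global _global_priority
        if self._priority is None:
            _global_priority += 1      # undo the decrement from record()

outer, inner = _Recorder(), _Recorder()
outer.record()                         # claims 0
inner.record()                         # claims -1: nested => lower priority
assert (outer.claimed, inner.claimed) == (0, -1)
inner.release()
outer.release()
assert _global_priority == 0           # counter fully restored
```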
@@ -333,3 +344,41 @@ class GradManager:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.release()
+
+    def __and__(self, other):
+        if isinstance(other, GradManager):
+            return GradManagerGroup([self, other])
+        return NotImplemented
+
+    __rand__ = __and__
+
+
+class GradManagerGroup:
+    def __init__(self, gms) -> None:
+        self._gms = list(gms)
+
+    def merge_with(self, other):
+        if isinstance(other, GradManager):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
+            return NotImplemented
+        return GradManagerGroup([*self._gms, *other._gms])
+
+    __and__ = merge_with
+    __rand__ = merge_with
+    __or__ = merge_with
+    __ror__ = merge_with
+
+    def __enter__(self):
+        global _global_priority
+        _global_priority += 1
+        for gm in self._gms:
+            gm._priority = _global_priority
+            gm.record()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _global_priority
+        _global_priority -= 1
+        for gm in self._gms:
+            gm.release()
+            gm._priority = None
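`GradManagerGroup` exists so that sibling managers can record as peers: entering the group bumps the shared counter once and hands the same priority to every member, instead of the strictly decreasing priorities plain nesting would assign. A hypothetical usage sketch, assuming `w` and `x` are pre-existing megengine Tensors (`|` and `&` are interchangeable here):

```python
# Hypothetical usage; w and x are assumed to be existing megengine Tensors.
from megengine.autodiff import GradManager

gm_w = GradManager()
gm_w.attach([w])
gm_x = GradManager()
gm_x.attach([x])

with gm_w | gm_x:                # both managers record at one priority level
    y = (w * x).sum()
    gm_w.backward(y)             # gm_x recorded the same history as a peer
```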
@@ -48,6 +48,14 @@ class Grad:
         _grad_manager_dict[self._name] = self
 
     @property
+    def _priority(self):
+        return self._impl.priority
+
+    @_priority.setter
+    def _priority(self, priority):
+        self._impl.priority = priority
+
+    @property
     def _name(self):
         return self._impl.name
@@ -54,7 +54,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
     for (size_t i = 0; i < ctx.nargs; ++i) {
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
+        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
     }
     mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
                bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
@@ -321,7 +321,7 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
     for (size_t i = 0; i < ctx.nargs; ++i) {
         inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
-        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
+        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
     }
     ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
@@ -365,25 +365,19 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
 } // namespace
 
 apply_result_t apply_grad(ApplyContext& ctx) {
-    std::shared_ptr<GradKey> grad_key;
+    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
     for (size_t i = 0; i < ctx.nargs; ++i) {
         auto* tensor = ctx.args[i];
-        if (tensor->m_grad_info.grad_fn) {
-            auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
-            // tensor is attached to a live GradKey
-            if (input_grad_key && input_grad_key->active) {
-                if (grad_key) {
-                    if (grad_key != input_grad_key) {
-                        PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-                        throw pyext17::py_err_set();
-                    }
-                } else {
-                    grad_key = std::move(input_grad_key);
-                }
+        if (!tensor->m_grad_info_dict.empty()) {
+            size_t grad_cnt = 0;
+            for (auto&& grad_info: tensor->m_grad_info_dict) {
+                auto input_grad_key = grad_info.grad_fn->key.lock();
+                if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
+                    grad_keys.insert(input_grad_key);
+                    grad_cnt++;
                 }
-            } else {
-                // cleanup stale grad info
-                // under what condition?
-                tensor->m_grad_info = {};
             }
+            if (!grad_cnt) {
                 tensor->m_flags &= ~Flags::GRAD;
             }
         } else {
@@ -393,7 +387,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
 
     ctx.flags &= ~Flags::GRAD;
 
-    if (!grad_key) {
+    if (grad_keys.empty()) {
         return apply(ctx);
     }
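In Python terms, the rewritten loop gathers every live, unblocked key attached to any input, and clears the GRAD flag only on tensors whose keys have all died or are currently blocked. A rough rendering with stand-in types (`FakeKey`/`FakeTensor` are not the real `GradKey`/`Tensor`; `min_priority` plays the role of `GradKey::sm_min_priority`):

```python
# Rough Python rendering of the key-collection loop above; stand-in types only.
class FakeKey:
    def __init__(self, priority, active=True):
        self.priority, self.active = priority, active

class FakeTensor:
    def __init__(self, keys):
        self.keys, self.requires_grad = keys, True

def collect_grad_keys(tensors, min_priority):
    grad_keys = set()
    for t in tensors:
        live = [k for k in t.keys if k.active and k.priority >= min_priority]
        if live:
            grad_keys.update(live)
        else:
            t.requires_grad = False    # analogous to clearing Flags::GRAD
    return grad_keys

outer, inner = FakeKey(priority=0), FakeKey(priority=-1)
x = FakeTensor([outer, inner])
# While outer's backward runs, min_priority == 0 and inner is blocked:
assert collect_grad_keys([x], min_priority=0) == {outer}
```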
@@ -418,54 +412,65 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         return backward_graph_grad_rule(ctx, grad_fn_holder);
     }();
 
-    auto& grad_fn = grad_fn_holder.grad_fn;
-    if (!grad_fn) {
+    if (!grad_fn_holder.grad_fn) {
         return outputs;
     }
 
-    grad_fn->key = grad_key;
-    grad_fn->slots.resize(outputs.size());
-    grad_fn->dsts.reserve(ctx.nargs);
-
-    std::visit([&](auto& backward) {
-        using T = std::decay_t<decltype(backward)>;
-        if constexpr (std::is_same_v<T, std::monostate>) {
-            mgb_assert(0);
-        } else {
-            for (size_t i = 0; i < ctx.nargs; ++i) {
-                if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
-                    auto& input_grad_info = ctx.args[i]->m_grad_info;
-                    grad_fn->dsts.emplace_back(input_grad_info);
-                    // register as grad producer
-                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
-                } else {
-                    grad_fn->dsts.emplace_back();
-                }
-            }
-            for (size_t i = 0; i < outputs.size(); ++i) {
-                if (backward.output_requires_grad(i)) {
-                    if (backward.output_captured(i)) {
-                        // avoid reference cycle [Tensor <-> GradFn]
-                        static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
-                        outputs[i] = python::apply(op, outputs[i])[0];
-                    }
-                    // populate grad info of output tensor
-                    auto& grad_info = outputs[i]->m_grad_info;
-                    grad_info.grad_fn = grad_fn;
-                    grad_info.idx = i;
-                    grad_info.insert_after(grad_key->free_vars_head);
-                    outputs[i]->m_flags |= Flags::GRAD;
-                }
-            }
-        }
-    }, grad_fn->backward);
-
-    // record forward history
-    grad_key->tape.emplace_back(grad_fn);
+    for (auto&& grad_key: grad_keys) {
+        auto grad_fn = std::make_shared<GradFn>();
+        grad_fn->backward = grad_fn_holder.grad_fn->backward;
+        grad_fn->key = grad_key;
+        grad_fn->slots.resize(outputs.size());
+        grad_fn->dsts.reserve(ctx.nargs);
+
+        std::visit([&](auto& backward) {
+            using T = std::decay_t<decltype(backward)>;
+            if constexpr (std::is_same_v<T, std::monostate>) {
+                mgb_assert(0);
+            } else {
+                for (size_t i = 0; i < ctx.nargs; ++i) {
+                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
+                        auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
+                        grad_fn->dsts.emplace_back(input_grad_info);
+                        // register as grad producer
+                        grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
+                    } else {
+                        grad_fn->dsts.emplace_back();
+                    }
+                }
+                for (size_t i = 0; i < outputs.size(); ++i) {
+                    if (backward.output_requires_grad(i)) {
+                        if (backward.output_captured(i)) {
+                            // avoid reference cycle [Tensor <-> GradFn]
+                            static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
+                            outputs[i] = python::apply(op, outputs[i])[0];
+                        }
+                        // populate grad info of output tensor
+                        auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
+                        grad_info.grad_fn = grad_fn;
+                        grad_info.idx = i;
+                        grad_info.insert_after(grad_key->free_vars_head);
+                        outputs[i]->m_flags |= Flags::GRAD;
+                    }
+                }
+            }
+        }, grad_fn->backward);
+
+        // record forward history
+        grad_key->tape.emplace_back(grad_fn);
+    }
 
     return outputs;
 }
 
+PyObject* GradKeyWrapper::get_priority() {
+    return py::cast(m_key->priority).release().ptr();
+}
+
+void GradKeyWrapper::set_priority(pybind11::handle priority) {
+    m_key->priority = py::cast<int>(priority);
+}
+
 void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
     if (nargs != 2) {
         throw py::type_error("expect 2 arguments");
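The loop over `grad_keys` fans the recording out: each collected key gets its own `GradFn` sharing the one backward rule, wires only the inputs attached to that key into `dsts`, gives every output a per-key `GradInfo`, and appends the function to that key's tape. Roughly, with stand-in data structures:

```python
# Rough Python rendering of the per-key GradFn fan-out; stand-in types only.
class Key:
    def __init__(self):
        self.tape = []                 # forward history for this key

def record_on_all_keys(grad_keys, inputs, outputs, backward):
    for key in grad_keys:
        fn = {"backward": backward, "key": key, "dsts": []}
        for t in inputs:
            # only inputs attached to *this* key become grad destinations
            fn["dsts"].append(t["grad_info"].get(key))
        for i, out in enumerate(outputs):
            # every output gains one GradInfo entry per key
            out["grad_info"][key] = {"grad_fn": fn, "idx": i}
        key.tape.append(fn)

k1, k2 = Key(), Key()
x = {"grad_info": {k1: "slot-x"}}      # x is attached to k1 only
y = {"grad_info": {}}
record_on_all_keys([k1, k2], [x], [y], backward=None)
assert len(k1.tape) == len(k2.tape) == 1
assert set(y["grad_info"]) == {k1, k2}
```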
@@ -488,24 +493,21 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
         throw py::value_error("grad key finalized");
     }
 
-    if (tensor->m_grad_info.grad_fn) {
-        if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
-            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-            throw pyext17::py_err_set();
-        }
-        if (tensor->m_grad_info->callback) {
+    if (tensor->m_grad_info_dict.count(this)) {
+        if (tensor->m_grad_info_dict.at(this)->callback) {
             throw py::value_error("callback already set on this tensor");
         }
     } else {
-        tensor->m_grad_info.idx = 0;
-        auto& grad_fn = tensor->m_grad_info.grad_fn;
+        auto& grad_info = tensor->m_grad_info_dict[this];
+        grad_info.idx = 0;
+        auto& grad_fn = grad_info.grad_fn;
         grad_fn = std::make_shared<GradFn>();
         grad_fn->key = shared_from_this();
         grad_fn->slots.resize(1);
-        tensor->m_grad_info.insert_after(free_vars_head);
+        grad_info.insert_after(free_vars_head);
         tensor->m_flags |= Flags::GRAD;
     }
-    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
+    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
 }
 
 template<typename T>
@@ -530,8 +532,15 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
     active = false;
     struct CleanupGuard {
         GradKey* owner;
-        CleanupGuard(GradKey* this_) : owner(this_) {}
-        ~CleanupGuard() {owner->cleanup();}
+        int priority_backup;
+        CleanupGuard(GradKey* this_) : owner(this_) {
+            priority_backup = sm_min_priority;
+            sm_min_priority = owner->priority;
+        }
+        ~CleanupGuard() {
+            owner->cleanup();
+            sm_min_priority = priority_backup;
+        }
     } _cleanup_guard(this);
 
     if (tape.empty()) return;
@@ -542,14 +551,16 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
     }
 
     for (size_t i = 0; i < tensors.size(); ++i) {
-        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
-        if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
-            grad_info->grad = grads[i]->m_tensor;
+        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
+            continue;
         }
+        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
+        grad_info->grad = grads[i]->m_tensor;
     }
 
     std::vector<std::shared_ptr<GradFn>> ref_keeper;
     ref_keeper.reserve(tape.size());
 
     // back-propagation in reverse order
     for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
         auto&& grad_fn = tape[k].lock();
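The `CleanupGuard` change is where priorities pay off: for the duration of a key's `backward`, `sm_min_priority` is raised to that key's priority, so strictly lower-priority keys are blocked and do not re-record the operations replayed from the tape, while higher-priority keys still can, which is exactly what a second-order grad needs. The same save/restore pattern as a Python context manager (stand-in for the C++ RAII guard):

```python
# Stand-in sketch of the CleanupGuard save/restore; _min_priority plays the
# role of GradKey::sm_min_priority.
import contextlib

_min_priority = float("-inf")          # nothing is blocked at rest

@contextlib.contextmanager
def backward_scope(key_priority):
    global _min_priority
    backup = _min_priority
    _min_priority = key_priority       # block keys with lower priority
    try:
        yield
    finally:
        _min_priority = backup         # restored even if backward throws

def blocked(priority):
    return priority < _min_priority

with backward_scope(-1):               # inner key (priority -1) backprops
    assert not blocked(0)              # outer key (priority 0) still records
    assert blocked(-2)                 # deeper keys are shut out
assert _min_priority == float("-inf")
```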
@@ -619,13 +630,14 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
         PyErr_SetString(PyExc_TypeError, "expect Tensor");
         return nullptr;
     }
-    auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
-    if (grad_fn && grad_fn->key.lock() == m_key) {
+    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
         Py_RETURN_TRUE;
     }
     Py_RETURN_FALSE;
 }
 
+int GradKey::sm_min_priority = std::numeric_limits<int>::min();
+
 GradKey::~GradKey() {
     cleanup();
 }
@@ -635,4 +647,41 @@ std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
     return registry;
 }
 
+void GradInfoCollection::_shrink() {
+    auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
+    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
+    m_storage.erase(iter, m_storage.end());
+}
+
+bool GradInfoCollection::contains(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return true;
+        }
+    }
+    return false;
+}
+
+GradInfo& GradInfoCollection::operator[](GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    m_storage.emplace_back();
+    return m_storage.back();
+}
+
+GradInfo& GradInfoCollection::at(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    mgb_assert(false);
+}
+
 } // namespace mgb::imperative::python
@@ -26,12 +26,18 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool active = true;
     GradInfo::head_t free_vars_head;
     std::vector<std::weak_ptr<GradFn>> tape;
+    int priority = 0;
 
     ~GradKey();
 
     void attach(Tensor* tensor, pybind11::object callback);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     void cleanup();
+    bool is_blocked() const {
+        return priority < sm_min_priority;
+    }
+private:
+    static int sm_min_priority;
 };
 
 struct GradKeyWrapper {
@@ -44,6 +50,8 @@ struct GradKeyWrapper {
     PyObject* get_name();
     void set_name(pybind11::handle name);
+    PyObject* get_priority();
+    void set_priority(pybind11::handle priority);
     void attach(PyObject*const* args, size_t nargs);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     PyObject* is_attached_to(PyObject*const* args, size_t nargs);
@@ -150,7 +158,7 @@ using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::M
 std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
 
 inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
-    return bool(ctx.args[i]->m_grad_info.grad_fn);
+    return !ctx.args[i]->m_grad_info_dict.empty();
 }
 
 struct GradRuleFallback : std::exception {};
@@ -15,6 +15,7 @@
 
 namespace mgb::imperative::python {
 
+struct GradKey;
 struct GradFn;
 struct GradSlot;
 
@@ -32,6 +33,10 @@ struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::be
     GradInfo(GradInfo&&) = default;
     GradInfo& operator=(GradInfo&) = default;
     GradInfo& operator=(GradInfo&&) = default;
+    GradInfo(const GradInfo& rhs) : GradInfo(const_cast<GradInfo&>(rhs)) {}
+    GradInfo& operator=(const GradInfo& rhs) {
+        return *this = const_cast<GradInfo&>(rhs);
+    }
 };
 
 } // namespace mgb::imperative::python
@@ -182,7 +182,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
     if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
         SmallVector<cg::VarNode*> vinputs(nargs);
         for (size_t i = 0; i < nargs; ++i) {
-            vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
+            vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
         }
         auto op = ctx.op.get();
         auto rst = OpDef::apply_on_var_node(*op, vinputs);
@@ -17,6 +17,7 @@
 #include "megbrain/imperative/interpreter.h"
 #include "pybind11/pybind11.h"
 
 #include <string>
+#include <unordered_map>
 
 #include "./pyext17.h"
@@ -36,6 +37,8 @@ struct ObjectPtr : B {
 
 namespace mgb::imperative::python {
 
+struct GradKey;
+
 extern interpreter::Interpreter::Channel* interpreter_for_py;
 
 class SharedHandle {
@@ -58,6 +61,34 @@ public:
 };
 
+// impl in grad.cpp
+class GradInfoCollection {
+private:
+    SmallVector<GradInfo> m_storage;
+protected:
+    void _shrink();
+public:
+    bool contains(GradKey* key);
+    GradInfo& operator[](GradKey* key);
+    GradInfo& at(GradKey* key);
+    bool empty() {
+        _shrink();
+        return m_storage.empty();
+    }
+    auto begin() {
+        _shrink();
+        return m_storage.begin();
+    }
+    auto end() {
+        _shrink();
+        return m_storage.end();
+    }
+    size_t count(GradKey* key) {
+        return contains(key) ? 1 : 0;
+    }
+};
+
 struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     using flags_t = uint64_t;
 
@@ -69,7 +100,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     flags_t m_flags = 0;
 
-    GradInfo m_grad_info;
+    GradInfoCollection m_grad_info_dict;
     TraceInfo m_trace_info;
     SharedHandle m_handle;
     std::string user_custom_name;
@@ -88,7 +119,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     inline std::shared_ptr<Tensor> copy() {
         auto ret = std::make_shared<Tensor>(m_handle);
         ret->m_flags = m_flags;
-        ret->m_grad_info = m_grad_info;
+        ret->m_grad_info_dict = m_grad_info_dict;
         ret->m_trace_info = m_trace_info;
         ret->m_var = m_var;
         return ret;
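`GradInfoCollection`, which replaces the single `m_grad_info` field, is effectively a tiny map keyed by `GradKey` identity: a flat `SmallVector` scanned linearly and pruned lazily via `_shrink()`, betting that a tensor is attached to only a handful of live keys at once. A hypothetical Python analogue of that behavior:

```python
# Hypothetical Python analogue of GradInfoCollection; entries whose key has
# died are pruned on every access, mirroring _shrink().
import weakref

class GradInfoCollection:
    def __init__(self):
        self._storage = []             # flat list of (weakref-to-key, info)

    def _shrink(self):
        self._storage = [(r, v) for r, v in self._storage if r() is not None]

    def __contains__(self, key):       # ~ contains()/count()
        self._shrink()
        return any(r() is key for r, _ in self._storage)

    def __getitem__(self, key):        # ~ operator[]: inserts if missing
        self._shrink()
        for r, v in self._storage:
            if r() is key:
                return v
        self._storage.append((weakref.ref(key), {}))
        return self._storage[-1][1]

class Key:                             # weakref-able stand-in for GradKey
    pass

coll, k = GradInfoCollection(), Key()
coll[k]["idx"] = 0
assert k in coll
del k                                  # key dies; the entry is pruned lazily
coll._shrink()
assert coll._storage == []
```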
@@ -108,21 +108,24 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
-@pytest.mark.skip(reason="high order gradient was not implemented yet")
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)
     ones = as_tensor(np.ones_like(x_np))
 
     grad = Grad().wrt(x, callback=save_to(x))
+    grad._priority = -1
     grad2 = Grad().wrt(x, callback=save_to(x))
+    grad2._priority = 0
 
     y = cos(x)
 
     grad(y, ones)
+    z = x.grad
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
 
-    grad2(x.grad, ones)
+    x.grad = None
+    grad2(z, ones)
     np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np))
@@ -398,20 +398,6 @@ OP_TRAIT_REG(Copy, Copy)
         .fallback();
 }} // copy
 
-namespace { namespace identity {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = def.cast_final_safe<Identity>();
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::Identity::make(inputs[0], config);
-}
-
-OP_TRAIT_REG(Identity, Identity)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // identity
-
 namespace { namespace assert_equal {
 auto apply_on_var_node(
         const OpDef& def,
@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
+#include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/opr/utility.h"
@@ -32,4 +33,25 @@ OP_TRAIT_REG(FastpathCopy,FastpathCopy)
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);
 
+namespace { namespace identity {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Identity>();
+    mgb_assert(inputs.size() == 1);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Identity::make(inputs[0], config);
+}
+
+auto apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    return SmallVector<TensorPtr>{inputs[0]};
+}
+
+OP_TRAIT_REG(Identity, Identity)
+    .apply_on_var_node(apply_on_var_node)
+    .apply_on_physical_tensor(apply_on_physical_tensor)
+    .fallback();
+}} // identity
+
 } // namespace mgb::imperative
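Registering `apply_on_physical_tensor` makes `Identity` executable eagerly by simply forwarding its input tensor, where previously it only had a graph-mode lowering. A hedged smoke test; the module paths below are assumptions about MegEngine internals and may differ across versions:

```python
# Assumed internal entry points; verify against your MegEngine version.
import numpy as np
import megengine
from megengine.core.ops.builtin import Identity
from megengine.core._imperative_rt.core2 import apply

x = megengine.tensor(np.arange(4, dtype="float32"))
(y,) = apply(Identity(), x)            # eagerly forwards the input tensor
np.testing.assert_array_equal(y.numpy(), x.numpy())
```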