@@ -265,20 +265,21 @@ void GradKey::backward() {
 
 GradValue::ref_t GradKey::attach(
         ValueRef tensor, std::function<void(ValueRef)> callback) {
-    auto grad_value = tensor.as_ref(m_value_type);
-    if (grad_value) {
-        mgb_assert(!tensor.cast(m_value_type).slot()->callback, "callback exists");
-    } else {
-        GradSlotPtr grad_slot;
-        auto& grad_fn = grad_slot.m_fn;
-        grad_fn = LocalPtr<GradFn>::make();
-        grad_fn->m_key = shared_from_this();
-        grad_fn->m_slots.resize(1);
-        grad_slot.m_index = 0;
-        grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot);
+    // always create a new grad value
+    GradSlotPtr grad_slot;
+    auto& grad_fn = grad_slot.m_fn;
+    grad_fn = LocalPtr<GradFn>::make();
+    grad_fn->m_key = shared_from_this();
+    grad_fn->m_slots.resize(1);
+    grad_fn->m_slots[0].callback = callback;
+    grad_slot.m_index = 0;
+    if (auto&& grad_value = tensor.as_ref(m_value_type)) {
+        grad_fn->m_backward.emplace<IdentityBackward>();
+        grad_fn->m_dests.push_back(grad_value->m_slot);
+        tensor = grad_value->m_value;
+        m_tape.emplace_back(grad_fn, nullptr);
     }
-    grad_value->slot().m_fn->m_slots[0].callback = callback;
-    return grad_value;
+    return m_value_type.make(tensor, shared_from_this(), grad_slot);
 }
 
 void GradKey::freeze() {
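
The rewritten attach() above always builds a fresh GradFn whose single slot
holds the new callback. When the tensor already carries a grad value for this
key (the case that previously hit the "callback exists" assertion), the fresh
slot is chained to the old one through an IdentityBackward entry pushed onto
m_tape, so a later backward() delivers the gradient to both callbacks. Below
is a minimal, self-contained sketch of that chaining idea; Slot, IdentityFn,
and the plain vector tape are invented stand-ins for GradSlot, GradFn, and
m_tape, not MegEngine's real types.

// Sketch only: invented stand-ins, not MegEngine's real API.
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

struct Slot {
    double grad = 0;
    std::function<void(double)> callback;  // fired with the final gradient
};

struct IdentityFn {                  // a GradFn whose backward is the identity
    std::shared_ptr<Slot> src;       // slot created by the newer attach
    std::shared_ptr<Slot> dest;      // slot created by the earlier attach
};

int main() {
    std::vector<IdentityFn> tape;

    auto first = std::make_shared<Slot>();
    first->callback = [](double g) { std::printf("first callback: %g\n", g); };

    // Re-attach the same value: make a fresh slot, chain it to the old one.
    auto second = std::make_shared<Slot>();
    second->callback = [](double g) { std::printf("second callback: %g\n", g); };
    tape.push_back({second, first});

    // Backward: seed the newest slot, replay the tape in reverse; each
    // identity entry forwards the gradient unchanged to the older slot.
    second->grad = 1.0;
    for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
        if (it->src->callback) it->src->callback(it->src->grad);
        it->dest->grad += it->src->grad;
    }
    if (first->callback) first->callback(first->grad);
    return 0;
}

Run, the sketch prints both callbacks with the same gradient, which is the
observable effect of the identity chaining.
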
@@ -424,22 +425,17 @@ ValueRefList GradTransformation::apply_transformation(
         return outputs;
     } else if (op.is<CreateTensor>()) {
         return imperative::apply(op, inputs);
-    }
-    if (auto* attach_grad = op.as<AttachGrad>()) {
-        auto& tensor = inputs[0];
-        if (auto&& grad_value = tensor.as_ref(m_value_type)) {
-            mgb_assert(!has_key(attach_grad->key()));
-            auto output = fallback()[0];
-            return record_grad(m_value_type.make(output, m_key, grad_value->slot()));
-        } else if (!has_key(attach_grad->key())) {
+    } else if (auto* attach_grad = op.as<AttachGrad>()) {
+        if (!has_key(attach_grad->key())) {
             return fallback();
         } else {
             GenericFunction callback =
                     (GenericFunction&)inputs[1].cast<FunctionValue>();
-            auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
-                auto ret = callback({&grad, 1});
-                assert(ret.empty());
-            });
+            auto output =
+                    attach_grad->key()->attach(inputs[0], [callback](ValueRef grad) {
+                        auto ret = callback({&grad, 1});
+                        mgb_assert(ret.empty());
+                    });
             return {record_grad(output)};
         }
     } else if (auto* grad_backward = op.as<GradBackward>()) {
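
With attach() now handling re-attachment itself, the AttachGrad branch above
drops its grad_value special case, joins the surrounding else-if chain, reads
inputs[0] directly instead of a tensor alias, and upgrades assert(ret.empty())
to mgb_assert so the check is not compiled out under NDEBUG. What remains is
wrapping the user's GenericFunction (list in, list out) into the
void(ValueRef) callback that attach() expects. A sketch of that wrapping
pattern follows, under invented stand-ins: a std::string plays ValueRef and a
vector-based callable plays the real span-based GenericFunction.

// Sketch only: stand-in types, not MegEngine's real ValueRef/GenericFunction.
#include <cassert>
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

using ValueRef = std::string;
using GenericFunction =
        std::function<std::vector<ValueRef>(std::vector<ValueRef>)>;

// Wrap the list-in/list-out callback into the void(ValueRef) form,
// insisting (as the diff does with mgb_assert) that it returns nothing.
std::function<void(ValueRef)> wrap(GenericFunction callback) {
    return [callback](ValueRef grad) {
        auto ret = callback({grad});
        assert(ret.empty() && "grad callback must not return values");
        (void)ret;  // silence the unused warning when NDEBUG disables assert
    };
}

int main() {
    GenericFunction cb = [](std::vector<ValueRef> args) {
        std::printf("got grad: %s\n", args[0].c_str());
        return std::vector<ValueRef>{};  // a grad callback yields no outputs
    };
    wrap(cb)("dy");
    return 0;
}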