
feat(imperative): add inplace add_update option in optimizer

GitOrigin-RevId: b8feb49321
Branch: release-1.2
Megvii Engine Team, 4 years ago
Parent commit c49427d15a
23 changed files with 337 additions and 105 deletions
1.  imperative/python/megengine/functional/inplace.py (+15, -0)
2.  imperative/python/megengine/jit/tracing.py (+4, -2)
3.  imperative/python/megengine/optimizer/adam.py (+46, -15)
4.  imperative/python/megengine/optimizer/optimizer.py (+2, -1)
5.  imperative/python/megengine/optimizer/sgd.py (+21, -2)
6.  imperative/python/megengine/tensor.py (+2, -2)
7.  imperative/python/src/grad_override.cpp (+2, -2)
8.  imperative/python/src/tensor.cpp (+6, -5)
9.  imperative/python/test/integration/test_sgd_momentum.py (+14, -4)
10. imperative/python/test/unit/test_tracing.py (+0, -1)
11. imperative/src/impl/dnn_op_helper.h (+39, -5)
12. imperative/src/impl/interpreter_impl.cpp (+4, -3)
13. imperative/src/impl/interpreter_impl.h (+3, -2)
14. imperative/src/impl/ops/cond_take.cpp (+1, -39)
15. imperative/src/impl/ops/elemwise.cpp (+133, -0)
16. imperative/src/impl/proxy_graph_detail.cpp (+14, -14)
17. imperative/src/include/megbrain/imperative/interpreter.h (+1, -1)
18. imperative/src/include/megbrain/imperative/physical_tensor.h (+4, -0)
19. imperative/src/include/megbrain/imperative/proxy_graph_detail.h (+4, -0)
20. src/core/include/megbrain/ir/ops.td (+2, -0)
21. src/opr/impl/basic_arith.cpp (+2, -5)
22. src/opr/impl/internal/identical_fwd.cpp (+16, -0)
23. src/opr/include/megbrain/opr/internal/identical_fwd.h (+2, -2)

imperative/python/megengine/functional/inplace.py (+15, -0)

@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ..core._imperative_rt.core2 import apply
+from ..core.ops import builtin
+from ..core.ops.builtin import InplaceAdd
+
+
+def _inplace_add_(dest, delta, alpha, beta):
+    return dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
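
Note: `_inplace_add_(dest, delta, alpha, beta)` overwrites `dest` in place with `alpha * dest + beta * delta` (`alpha` and `beta` must be 1-element float32 tensors), then rebinds the Python wrapper to the updated storage via `_reset`. A minimal usage sketch, assuming this commit is applied (illustrative, not part of the diff):

import numpy as np
import megengine as mge
from megengine.functional.inplace import _inplace_add_

dest = mge.tensor(np.array([1.0, 2.0], dtype=np.float32))
delta = mge.tensor(np.array([10.0, 10.0], dtype=np.float32))
alpha, beta = mge.tensor([1.0]), mge.tensor([0.5])  # dest = 1.0 * dest + 0.5 * delta

_inplace_add_(dest, delta, alpha, beta)
np.testing.assert_allclose(dest.numpy(), [6.0, 7.0])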

imperative/python/megengine/jit/tracing.py (+4, -2)

@@ -502,6 +502,8 @@ class trace:
             # profile
             if self._profiling:
                 self._profiler = GraphProfiler(graph)
+            if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
+                graph.options.var_sanity_check_first_run = False

     def _compile(self):
         graph = self._graph = G.Graph()

@@ -1073,7 +1075,7 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
     return active_trace._apply_op(op, args)


-def apply_const_compiled_mode(value, dtype, device, is_const):
+def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
     if skip_tracing:
         args = [
             RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x

@@ -1099,7 +1101,7 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
     return list(outputs)


-def apply_const_with_tracing(value, dtype, device, is_const):
+def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
     if active_trace._symbolic:
         outputs = apply_const_symbolic_mode(value, dtype, device)
     else:
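
Note: the in-place update path is opt-in via the MEGENGINE_INPLACE_UPDATE environment variable; under trace it also disables the graph's first-run var sanity check, which would otherwise flag the forcibly overwritten input var. A hedged sketch of how a user would enable it (the variable must be read before the optimizer steps):

import os
os.environ["MEGENGINE_INPLACE_UPDATE"] = "1"
# ... build model / optimizer / trace as usual; parameter and optimizer-state
# updates now reuse existing device storage instead of allocating new tensors.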


imperative/python/megengine/optimizer/adam.py (+46, -15)

@@ -6,8 +6,10 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
 from typing import Iterable, Tuple, Union

+from ..functional.inplace import _inplace_add_
 from ..tensor import Parameter, tensor
 from .optimizer import Optimizer

@@ -58,15 +60,24 @@ class Adam(Optimizer):
         eps = param_group["eps"]
         beta0, beta1 = param_group["betas"]

+        def make_scalar(val):
+            return tensor([val])
+
         # since `convert_inputs` is disabled for param updates,
         # scalar should be explicitly transferred to tensor
-        _lr = tensor([lr])
-        _weight_decay = tensor([weight_decay])
-        _eps = tensor([eps])
-        _beta0, _beta1 = tensor([beta0]), tensor([beta1])
-
-        c1 = tensor([1.0])
-        c05 = tensor([0.5])
+        _lr, _neg_lr = map(make_scalar, (lr, -lr))
+        _weight_decay = make_scalar(weight_decay)
+        _eps = make_scalar(eps)
+        _beta0, _beta1 = map(make_scalar, (beta0, beta1))
+
+        c1, c05 = map(make_scalar, (1.0, 0.5))
+
+        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
+        if inplace_mode:
+            # reduce device sync
+            c1_sub_beta0, c1_sub_beta1 = map(make_scalar, (1 - beta0, 1 - beta1))

         for param in param_group["params"]:

             if param.grad is None:

@@ -77,18 +88,38 @@ class Adam(Optimizer):
                 grad += param * _weight_decay

             states = self._state[param]
-            step = states["step"]

+            step, exp_avg, exp_avg_sq = (
+                states["step"],
+                states["exp_avg"],
+                states["exp_avg_sq"],
+            )
+
+            if inplace_mode:
+                _inplace_add_(step, c1, alpha=c1, beta=c1)
+                _inplace_add_(exp_avg, grad, alpha=_beta0, beta=c1_sub_beta0)
+                _inplace_add_(
+                    exp_avg_sq, grad * grad, alpha=_beta1, beta=c1_sub_beta1,
+                )
+
+                delta = (exp_avg / (c1 - _beta0 ** step)) / (
+                    (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
+                )
+                _inplace_add_(param, delta, alpha=c1, beta=_neg_lr)
+                continue
+
+            # step = step + c1
             step += c1
-            exp_avg = states["exp_avg"]
-            exp_avg_sq = states["exp_avg_sq"]
-            exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
-            exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
+
+            # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
+            exp_avg *= _beta0
+            exp_avg += grad * (c1 - _beta0)
+
+            # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
+            exp_avg_sq *= _beta1
+            exp_avg_sq += (c1 - _beta1) * (grad * grad)

             delta = (exp_avg / (c1 - _beta0 ** step)) / (
                 (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
             )
             param -= _lr * delta
+
+            # not inplace change, need to update underlying tensor handler in state
+            states["exp_avg"]._reset(exp_avg)
+            states["exp_avg_sq"]._reset(exp_avg_sq)

imperative/python/megengine/optimizer/optimizer.py (+2, -1)

@@ -96,6 +96,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
+            param._reset(Tensor(param.numpy(), no_cache=True))

         for name, default in self._defaults.items():
             if default is required and name not in param_group:

@@ -121,7 +122,7 @@ class Optimizer(metaclass=ABCMeta):
             initializer = np.zeros(param.shape, dtype=np.float32)
         state_dict = self._state.setdefault(param, {})
         assert state_name not in state_dict
-        state = Tensor(initializer)
+        state = Tensor(initializer, no_cache=True)
         state_dict[state_name] = state

     @abstractmethod
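
Note on `no_cache=True`: a tensor that will be overwritten in place must own storage nothing else aliases. Judging from the interpreter change further below, the default `put` path goes through `Tensor::make`, which may return cached/shared storage for equal host values, while `no_cache` forces a fresh `Tensor`; so parameters and optimizer state are re-created here with unshared storage (this reading of the caching behavior is an assumption). Sketch of the user-visible API:

import numpy as np
from megengine import Tensor

init = np.zeros((4,), dtype=np.float32)
a = Tensor(init, no_cache=True)  # safe in-place target
b = Tensor(init, no_cache=True)  # does not alias a, despite equal values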


imperative/python/megengine/optimizer/sgd.py (+21, -2)

@@ -6,8 +6,10 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
 from typing import Iterable, Union

+from ..functional.inplace import _inplace_add_
 from ..tensor import Parameter, tensor
 from .optimizer import Optimizer

@@ -54,10 +56,16 @@ class SGD(Optimizer):

         # since `convert_inputs` is disabled for param updates,
         # scalar should be explicitly transferred to tensor
         _lr = tensor([lr])
         _weight_decay = tensor([weight_decay])
         _momentum = tensor([momentum])

+        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
+        if inplace_mode:
+            _neg_lr = tensor([-lr])
+            c1 = tensor([1.0])
+
         for param in param_group["params"]:
             if param.grad is None:
                 continue

@@ -66,10 +74,21 @@ class SGD(Optimizer):
             if weight_decay != 0.0:
                 grad += param * _weight_decay

+            if inplace_mode:
+                if momentum:
+                    v = self._state[param]["momentum_buffer"]
+                    _inplace_add_(v, grad, alpha=_momentum, beta=c1)
+                    _inplace_add_(param, v, alpha=c1, beta=_neg_lr)
+                else:
+                    _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
+                continue
+
             if momentum:
                 v = self._state[param]["momentum_buffer"]
-                v = _momentum * v + grad
+                # v = v * _momentum + grad
+                v *= _momentum
+                v += grad
                 param -= _lr * v
+                self._state[param]["momentum_buffer"]._reset(v)
             else:
                 param -= _lr * grad
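
With the same `x = alpha * x + beta * y` semantics, the in-place branch reproduces classic SGD with momentum (illustrative comment sketch):

# v     = momentum * v     + 1     * grad   -> _inplace_add_(v, grad, alpha=_momentum, beta=c1)
# param = 1        * param + (-lr) * v      -> _inplace_add_(param, v, alpha=c1, beta=_neg_lr)
# without momentum:
# param = 1        * param + (-lr) * grad   -> _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)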

imperative/python/megengine/tensor.py (+2, -2)

@@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
     dmap_callback = None
     q_dict = {"mode": None, "scale": None, "zero_point": None}

-    def __new__(cls, data, dtype=None, device=None, is_const=False):
+    def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False):
         if device is None:
             cn = get_default_device()
         elif isinstance(device, str):

@@ -49,7 +49,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
             if 0 in data.strides:
                 data = data.squeeze().reshape(data.shape)

-        obj = _Tensor.__new__(cls, data, dtype, cn, is_const)
+        obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache)
         return obj

     @property


imperative/python/src/grad_override.cpp (+2, -2)

@@ -38,9 +38,9 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
 std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
     HostTensorND scalar{cn, {{1}, dtype::Float32()}};
     scalar.ptr<float>()[0] = v;
-    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar);
+    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
     auto&& t = std::make_shared<Tensor>(handle);
-    auto&& res = broadcast_to(t.get(), shape);
+    auto res = broadcast_to(t.get(), shape);
     return res;
 }




imperative/python/src/tensor.cpp (+6, -5)

@@ -231,13 +231,14 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         }
     } else {
         py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
+        if (nargs != 4 && nargs != 5) {
+            throw py::type_error("expect 4 or 5 arguments");
+        }
         auto data = tup[0].cast<py::array>();
         DType dtype = tup[1].cast<DType>();
         CompNode cn = tup[2].cast<CompNode>();
         bool is_const = tup[3].cast<bool>();
-        if (nargs != 4) {
-            throw py::type_error("expect 3 arguments");
-        }
+        bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false;

         // const op
         if (is_const && is_tracing) {

@@ -259,10 +260,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             interpreter::Interpreter::Handle handle;
             constexpr auto size_threshhold = TensorShape::MAX_NDIM;
             if (data.size() > size_threshhold) {
-                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
+                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype), no_cache);
             } else {
                 HostTensorND ret(cn);
-                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
+                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), no_cache);
             }

             m_tensor = std::make_shared<Tensor>(handle);


imperative/python/test/integration/test_sgd_momentum.py (+14, -4)

@@ -6,6 +6,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import itertools
+import os
+
 import numpy as np

 import megengine

@@ -58,13 +61,16 @@ def test_sgd_momentum():

     np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
     np.testing.assert_almost_equal(
-        optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
+        optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5
     )


 def test_sgd_momentum_trace():
-    for symbolic in (True, False):
+    origin_inplace = os.getenv("MEGENGINE_INPLACE_UPDATE")
+    symbolic = (True, False)
+    inplace = (0, 1)
+    for symbolic, inplace in itertools.product(symbolic, inplace):
+        os.environ["MEGENGINE_INPLACE_UPDATE"] = str(inplace)

         @trace(symbolic=symbolic)
         def train_func(data, *, model=None, optim=None, gm=None):

@@ -101,5 +107,9 @@ def test_sgd_momentum_trace():
         train_func(data, model=net, optim=optim, gm=gm)
         np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
         np.testing.assert_almost_equal(
-            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
+            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5
         )
+    if origin_inplace:
+        os.environ["MEGENGINE_INPLACE_UPDATE"] = origin_inplace
+    else:
+        del os.environ["MEGENGINE_INPLACE_UPDATE"]

imperative/python/test/unit/test_tracing.py (+0, -1)

@@ -325,7 +325,6 @@ def test_raise_on_trace():

     @trace
     def add_abc(a, b, c):
-        print("Hello")
         ps = a + b
         result = ps + c
         if step_count == bad_step:


imperative/src/impl/dnn_op_helper.h (+39, -5)

@@ -11,6 +11,7 @@

 #include "megbrain/comp_node_env.h"
 #include "megbrain/comp_node.h"
+#include "megbrain/imperative/physical_tensor.h"

 using namespace megdnn;

@@ -29,19 +30,21 @@ struct DnnOprCaller {
     Workspace workspace;
     std::unique_ptr<Opr> op;

-    DnnOprCaller(CompNode cn): cn(cn) {
+    DnnOprCaller(CompNode cn): cn(cn), op(create_operator(cn)) {}
+
+    static std::unique_ptr<Opr> create_operator(CompNode cn) {
         auto&& handle = MegDNNHandle::get(
                 CompNodeEnv::from_comp_node(cn)).handle();
-        op = handle->create_operator<Opr>();
+        return handle->create_operator<Opr>();
     }

     megdnn::Workspace create_workspace(TensorLayout layout) {
         dev_tensor = Tensor::make(layout, cn)->dev_tensor();
         workspace = megdnn::Workspace(dev_tensor.raw_ptr(),
                 dev_tensor.storage().size());
         return workspace;
     }

     ~DnnOprCaller() {
         using DT = CompNode::DeviceType;
         if (cn.device_type() == DT::CPU && cn != CompNode::default_cpu()) {

@@ -52,5 +55,36 @@
         }
     }
 };

+template <size_t OSize>
+class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy {
+    using Output = std::array<TensorPtr, OSize>;
+
+    CompNode m_cn;
+    Output m_out;
+
+public:
+    MegDNNDynOutMallocImpl(CompNode cn): m_cn{cn} {}
+
+    megdnn::TensorND alloc_output(
+            size_t id, DType dtype, const TensorShape &shape,
+            void *user_data) override {
+        TensorLayout m_layout(shape, dtype);
+        m_out[id] = Tensor::make(m_layout, m_cn);
+        return m_out[id]->dev_tensor().as_megdnn();
+    }
+
+    void* alloc_workspace(size_t sz, void *user_data) override {
+        return m_cn.alloc_device(sz);
+    }
+
+    void free_workspace(void *ptr, void *user_data) override {
+        m_cn.free_device(ptr);
+    }
+
+    TensorPtr at(size_t id) {
+        return m_out[id];
+    }
+};
+
 } // namespace imperative
 } // namespace mgb

imperative/src/impl/interpreter_impl.cpp (+4, -3)

@@ -28,13 +28,13 @@ Interpreter& Interpreter::inst() {
     return inst_;
 }

-void* ChannelImpl::put(const HostTensorND& value) {
+void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     auto info = alloc();
     info->desc.layout = value.layout();
     info->desc.comp_node = value.comp_node();
     info->desc.value = value.proxy_to_default_cpu();
     m_valid_handle.insert(info);
-    m_worker.add_task(Put{info, value});
+    m_worker.add_task(Put{info, value, no_cache});
     return info;
 }

@@ -395,7 +395,8 @@ void ChannelImpl::process_one_task(Command& cmd) {
         using T = std::remove_reference_t<decltype(cmd)>;
         try {
             if constexpr (std::is_same_v<T, Put>) {
-                produce_tensor(cmd.dest, Tensor::make(cmd.value));
+                auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
+                produce_tensor(cmd.dest, std::move(value));
             } else if constexpr (std::is_same_v<T, ApplyOp>) {
                 SmallVector<TensorPtr> tensor_inputs;
                 tensor_inputs.reserve(cmd.inputs.size());


imperative/src/impl/interpreter_impl.h (+3, -2)

@@ -45,7 +45,7 @@ struct TensorInfo {
     HostTensorND h_value;
     size_t locked = 0;
     size_t recompute_times = 0;

     struct ComputePath {
         std::shared_ptr<OpDef> op;
         SmallVector<TensorInfoPtr> inputs;

@@ -57,6 +57,7 @@
 struct Put {
     TensorInfo* dest;
     HostTensorND value;
+    bool no_cache = false;
 };
 struct ApplyOp {
     std::shared_ptr<OpDef> op;

@@ -92,7 +93,7 @@ struct ChannelImpl : Interpreter::Channel {
     ChannelImpl() : m_worker(this) {}
     ~ChannelImpl() override;

-    Handle put(const HostTensorND& value) override;
+    Handle put(const HostTensorND& value, bool no_cache) override;
     Handle put(const DeviceTensorND& value) override;

     void del(Handle) override;


imperative/src/impl/ops/cond_take.cpp (+1, -39)

@@ -20,44 +20,6 @@ namespace mgb::imperative {

 namespace {

-class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy {
-    using Output = std::array<TensorPtr, 2>;
-
-    CompNode m_cn;
-    Output m_out;
-
-public:
-    MegDNNDynOutMallocImpl(CompNode cn): m_cn{cn} {}
-
-    megdnn::TensorND alloc_output(
-            size_t id, DType dtype, const TensorShape &shape,
-            void *user_data) override;
-
-    void* alloc_workspace(size_t sz, void *user_data) override;
-    void free_workspace(void *ptr, void *user_data) override;
-    TensorPtr at(size_t id);
-};
-
-megdnn::TensorND MegDNNDynOutMallocImpl::alloc_output(
-        size_t id, DType dtype, const TensorShape &shape,
-        void * /*user_data*/) {
-    TensorLayout m_layout(shape, dtype);
-    m_out[id] = Tensor::make(m_layout, m_cn);
-    return m_out[id]->dev_tensor().as_megdnn();
-}
-
-void* MegDNNDynOutMallocImpl::alloc_workspace(size_t sz, void * /*user_data*/) {
-    return m_cn.alloc_device(sz);
-}
-
-void MegDNNDynOutMallocImpl::free_workspace(void *ptr, void * /*user_data*/) {
-    m_cn.free_device(ptr);
-}
-
-TensorPtr MegDNNDynOutMallocImpl::at(size_t id) {
-    return m_out[id];
-}
-
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {

@@ -94,7 +56,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
             dtype::Byte());

     auto dnn_workspace = dnn_op.create_workspace(m_layout);
-    MegDNNDynOutMallocImpl policy{inp->comp_node()};
+    MegDNNDynOutMallocImpl<2> policy{inp->comp_node()};

     dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
             msk->dev_tensor().as_megdnn(),


imperative/src/impl/ops/elemwise.cpp (+133, -0)

@@ -11,8 +11,11 @@

 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/opr/basic_arith.h"
+#include "megbrain/imperative/opr_utility.h"
+#include "megbrain/opr/utility.h"

 #include "../op_trait.h"
+#include "../dnn_op_helper.h"

 namespace mgb {
 namespace imperative {

@@ -84,12 +87,142 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return {Tensor::make(out)};
 }

+MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{
+public:
+    struct Param {
+        using Mode = megdnn::Elemwise::Param::Mode;
+        Mode mode;
+        size_t inplace_index;
+    };
+    using Mode = Param::Mode;
+    ForceInplaceElemwise(const VarNodeArray& inputs, Param param,
+                         OperatorNodeConfig config = {})
+            : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs), m_param{param} {
+        for (auto* input: inputs) {
+            add_input({input});
+        }
+        add_output(None)->
+                set_fwd_in2out_writable_force(input(param.inplace_index)).
+                add_flag(VarNode::Flag::NO_MEM_RECLAIM);
+    }
+    static SymbolVar make(const VarNodeArray& inputs, Param param) {
+        return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
+                inputs, param);
+    }
+    static cg::OperatorNodeBase* shallow_copy(
+            const serialization::OprShallowCopyContext &ctx,
+            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+            const OperatorNodeConfig &config);
+protected:
+    NodeProp* do_make_node_prop() const override {
+        auto ret = Super::do_make_node_prop();
+        ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
+        return ret;
+    }
+    void create_megdnn_opr() override {
+        auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
+        opr->param().mode = m_param.mode;
+        set_megdnn_opr(std::move(opr));
+    }
+    void scn_do_execute() override {
+        auto to_dnnnd = [&](auto* var){ return var->dev_tensor().as_megdnn(); };
+        megdnn::TensorNDArray inputs_dnnnd;
+        for (auto* input: input()) {
+            inputs_dnnnd.push_back(to_dnnnd(input));
+        }
+        mgb_assert(input(m_param.inplace_index)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
+                "ForceInplaceElemwise cannot be applied in internal tensor");
+        auto* out_dest = output(0);
+        auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
+        opr->exec(std::move(inputs_dnnnd),
+                to_dnnnd(out_dest));
+    }
+    void init_output_static_infer_desc() override {
+        using namespace cg::static_infer;
+
+        owner_graph()->static_infer_manager().register_shape_infer(
+                output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index)));
+    }
+private:
+    Param m_param;
+    void record_execute_deps(ExecDependencyArray& deps) override {
+        record_megdnn_opr(deps);
+    }
+};
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise);
+
+cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy(
+        const serialization::OprShallowCopyContext &ctx,
+        const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+        const OperatorNodeConfig &config) {
+    auto &&opr = opr_.cast_final_safe<ForceInplaceElemwise>();
+    auto* graph = ctx.owner_graph(opr, inputs);
+    return graph->insert_opr(std::make_unique<ForceInplaceElemwise>(inputs, opr.m_param, config));
+}
+
+MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy);
+
+cg::OperatorNodeBase* apply_inplace_add_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto dest = inputs[0], delta = inputs[1],
+         alpha = inputs[2], beta = inputs[3];
+    auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4;
+    return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1}).node()->owner_opr();
+}
+
+SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    auto dest = inputs[0], delta = inputs[1],
+         alpha = inputs[2], beta = inputs[3];
+    auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
+        return *tensor->get_value().ptr<float>();
+    };
+    DnnOprCaller<megdnn::AddUpdate> caller{dest->comp_node()};
+    caller.op->param() = { tensor_to_scalar(alpha), tensor_to_scalar(beta) };
+    caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn());
+    return { std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout()) };
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    mgb_assert(inputs.size() == 4, "invalid input number for inplace_add");
+    CompNode cn;
+    for (auto&& input: inputs) {
+        if (!cn.valid()) {
+            cn = input.comp_node;
+        } else {
+            mgb_assert(input.comp_node == cn, "inputs should be in same comp_node");
+        }
+    }
+    auto dest = inputs[0], delta = inputs[1],
+         alpha = inputs[2], beta = inputs[3];
+    bool succeed = dest.layout.ndim != 0;
+    if (succeed) {
+        mgb_assert(delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout), "dest and delta must have same shape");
+        mgb_assert(alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}), "alpha should be scalar");
+        mgb_assert(beta.layout.ndim == 0 || beta.layout.eq_shape({1}), "beta should be scalar");
+    }
+    mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32");
+    mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32");
+    return {{dest}, succeed};
+}
+
 OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
     .make_from_op_node(make_from_op_node)
     .apply_on_var_node(apply_on_var_node)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
     .apply_on_physical_tensor(apply_on_physical_tensor)
     .fallback();
+
+OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
+    .apply_on_var_node(apply_inplace_add_on_var_node)
+    .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
+    .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
+    .fallback();
+
 } // anonymous namespace

 } // namespace imperative
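
The two lowering paths agree on semantics: FUSE_MUL_ADD4 computes i0 * i1 + i2 * i3 over its four inputs, so with operand order (alpha, dest, beta, delta) and inplace_index = 1 the graph path overwrites dest with alpha * dest + beta * delta, which is exactly what AddUpdate{alpha, beta} does on the physical-tensor path. A plain-numpy model of that contract (illustrative only, not MegEngine API):

import numpy as np

def inplace_add(dest, delta, alpha, beta):
    # FUSE_MUL_ADD4(alpha, dest, beta, delta) == AddUpdate(alpha, beta):
    # dest <- alpha * dest + beta * delta, written into dest's own storage
    dest *= alpha
    dest += beta * delta
    return dest

p = np.array([1.0, 2.0], dtype=np.float32)
g = np.array([3.0, 4.0], dtype=np.float32)
inplace_add(p, g, alpha=1.0, beta=-0.1)  # an SGD step: p -= 0.1 * g
np.testing.assert_allclose(p, [0.7, 1.6])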


imperative/src/impl/proxy_graph_detail.cpp (+14, -14)

@@ -32,14 +32,22 @@ SmallVector<Tensor*> to_raw_ptr_array(
     return ret;
 }

+SmallVector<LogicalTensorDesc>
+infer_output_attrs(const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    auto&& graph = ProxyGraph::get_default_graph();
+    return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
+}
+
+} // anonymous namespace
+
 void exec(const OpDef& def,
-        const SmallVector<TensorPtr>& inputs_,
-        const SmallVector<TensorPtr>& outputs_) {
+        const SmallVector<TensorPtr>& inputs,
+        const SmallVector<TensorPtr>& outputs) {
     auto&& graph = ProxyGraph::get_default_graph();
-    auto inputs = to_raw_ptr_array(inputs_),
-         outputs = to_raw_ptr_array(outputs_);
+    auto raw_inputs = to_raw_ptr_array(inputs),
+         raw_outputs = to_raw_ptr_array(outputs);
     CompNode::UnorderedSet used_cns;
-    for (auto&& out: outputs) {
+    for (auto&& out: raw_outputs) {
         auto cn = out->comp_node();
         if (used_cns.insert(cn).second) {
             for (auto&& in: inputs) {

@@ -50,7 +58,7 @@ void exec(const OpDef& def,
             }
         }
     }
-    graph->invoke_op(def, inputs, outputs);
+    graph->invoke_op(def, raw_inputs, raw_outputs);
     for (auto&& cn: used_cns) {
         for (auto&& in: inputs) {
             if (in->comp_node() != cn) {

@@ -60,14 +68,6 @@
     }
 }

-SmallVector<LogicalTensorDesc>
-infer_output_attrs(const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
-    auto&& graph = ProxyGraph::get_default_graph();
-    return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
-}
-} // anonymous namespace
-
 SmallVector<TensorPtr>
 apply_on_physical_tensor(const OpDef& def,
         const SmallVector<TensorPtr>& inputs) {


imperative/src/include/megbrain/imperative/interpreter.h (+1, -1)

@@ -21,7 +21,7 @@ struct Interpreter {
     struct Channel {
         virtual ~Channel() = default;

-        virtual Handle put(const HostTensorND& value) = 0;
+        virtual Handle put(const HostTensorND& value, bool no_cache) = 0;
        virtual Handle put(const DeviceTensorND& value) = 0;

        virtual void del(Handle) = 0;


imperative/src/include/megbrain/imperative/physical_tensor.h (+4, -0)

@@ -101,6 +101,10 @@ public:
         return m_layout;
     }

+    size_t offset() const {
+        return m_offset;
+    }
+
     DeviceTensorND dev_tensor();

     static TensorPtr make_scalar(DTypeScalar value, CompNode cn);


imperative/src/include/megbrain/imperative/proxy_graph_detail.h (+4, -0)

@@ -24,6 +24,10 @@ apply_on_physical_tensor(const OpDef& def,
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs);

+void exec(const OpDef& def,
+        const SmallVector<TensorPtr>& inputs,
+        const SmallVector<TensorPtr>& outputs);
+
 BackwardGraphResult
 make_backward_graph(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,


src/core/include/megbrain/ir/ops.td (+2, -0)

@@ -239,4 +239,6 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam
     );
 }

+def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;
+
 #endif // MGB_OPS

src/opr/impl/basic_arith.cpp (+2, -5)

@@ -886,12 +886,9 @@ AddUpdate::AddUpdate(VarNode *dest, VarNode *delta,
         m_param{param}
 {
     auto dest_opr = dest->owner_opr();
-    mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
-            dest_opr->same_type<VolatileSharedDeviceTensor>()),
+    mgb_throw_if(dest_opr->same_type<ImmutableTensor>(),
             GraphError,
-            "AddUpdate must be applied on SharedDeviceTensor; "
-            "got %s{%s} actually",
-            dest_opr->cname(), dest_opr->dyn_typeinfo()->name);
+            "AddUpdate cannot be applied on ImmutableTensor; ");
     add_input({dest, delta});

     /*


src/opr/impl/internal/identical_fwd.cpp (+16, -0)

@@ -80,6 +80,22 @@ public:

 MGB_TYPEINFO_OBJ_IMPL(ForwardInputToOutput::MutableSrc);

+void ForwardInputToOutput::mixin_init_rt_force_dynamic_mem_alloc_imply_chain(
+        OperatorNodeBase &opr) {
+    VarNode *valid_out = nullptr;
+    for (auto i: opr.output()) {
+        if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
+            mgb_assert(!valid_out);
+            valid_out = i;
+        }
+    }
+    mgb_assert(valid_out);
+
+    // there may be many inputs (e.g. in opr::VirtualDep), but only the first one is forwarded
+    opr.input(0)->add_rt_force_dynamic_mem_alloc_imply_chain(valid_out);
+    valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0));
+}
+
 void ForwardInputToOutput::mixin_mem_plan_fwd_in2out_readonly(
         OperatorNodeBase& opr) {
     m_mem_fwd_success = opr.output(0)->set_fwd_in2out_readonly(


src/opr/include/megbrain/opr/internal/identical_fwd.h (+2, -2)

@@ -67,6 +67,7 @@ class ForwardInputToOutput: public cg::OperatorNodeMixinBase {

     virtual void mixin_scn_do_execute(OperatorNodeBase &opr);

+    void mixin_init_rt_force_dynamic_mem_alloc_imply_chain(OperatorNodeBase &opr);
     void mixin_mem_plan_fwd_in2out_readonly(OperatorNodeBase &opr);
     void mixin_init_output_static_infer_desc(OperatorNodeBase &opr);
     virtual cg::static_infer::ValueInferDesc mixin_get_static_infer_desc(OperatorNodeBase &opr);

@@ -173,8 +174,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ForwardInputToOutput,
 protected:
     using Super::Super;
     void init_rt_force_dynamic_mem_alloc_imply_chain() override {
-        mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o(
-                *this);
+        this->mixin_init_rt_force_dynamic_mem_alloc_imply_chain(*this);
     }

     void mem_plan_fwd_in2out_readonly() override {

