GitOrigin-RevId: d093778e10
release-0.6
@@ -213,6 +213,15 @@ if(MGE_WITH_TEST)
 endif()
 
 option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
+option(MGE_BUILD_XXX "Build _xxx.so instead of mgb.so " OFF)
+if(MGE_BUILD_XXX)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+option(MGE_BUILD_SDK "Build load_and_run" ON)
+if(MGE_BUILD_XXX)
+  set(MGE_BUILD_SDK OFF)
+endif()
 
 if(NOT MGE_WITH_CUDA)
   message("-- Disable distributed support, as CUDA is not enabled.")
@@ -522,7 +531,7 @@ endif()
 set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}")
+set(MGB_ENABLE_IMPERATIVE ${MGE_BUILD_XXX})
 
 # Write out megbrain_build_config.h
 # It defines macros needed by both megbrain and dnn
 configure_file(src/megbrain_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/megbrain_build_config.h)
@@ -566,14 +575,23 @@ if(MGE_WITH_DISTRIBUTED)
 endif()
 
 add_subdirectory(src)
-add_subdirectory(sdk/load-and-run)
+if(MGE_BUILD_SDK)
+  add_subdirectory(sdk/load-and-run)
+endif()
 
 if(MGE_WITH_PYTHON_MODULE)
-  add_subdirectory(python_module)
+  if(MGE_BUILD_XXX)
+    add_subdirectory(imperative)
+  else()
+    add_subdirectory(python_module)
+  endif()
 endif()
 
 if(MGE_WITH_TEST AND MGE_ENABLE_RTTI)
-  add_subdirectory(test)
+  if(NOT MGE_BUILD_XXX)
+    add_subdirectory(test)
+  endif()
 endif()
 
 if(TARGET mgb)
@@ -597,6 +615,21 @@ if(TARGET mgb)
         DEPENDS mgb
         VERBATIM
     )
+elseif(TARGET _xxx)
+    add_custom_target(
+        develop
+        COMMAND ${CMAKE_COMMAND} -E create_symlink
+            ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/$<TARGET_FILE_NAME:${MODULE_NAME}>
+            ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/$<TARGET_FILE_NAME:${MODULE_NAME}>
+        COMMAND ${CMAKE_COMMAND} -E create_symlink
+            ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/generated_ops.py
+            ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/generated_ops.py
+        COMMAND ${CMAKE_COMMAND} -E create_symlink
+            ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/param_defs.py
+            ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/param_defs.py
+        DEPENDS _xxx
+        VERBATIM
+    )
 endif()
 
 IF(APPLE)
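
Note on the hunk above: assuming stock CMake behavior (nothing patch-specific), a custom `develop` target like this is driven with `cmake --build <build-dir> --target develop`; the `create_symlink` commands then link the freshly built extension module and the two generated Python files back into the source tree, so in-place edits to the Python package take effect without reinstalling.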
@@ -59,7 +59,9 @@ install(TARGETS opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
 if(MGE_WITH_TEST)
-  add_subdirectory(test)
+  if(NOT MGE_BUILD_XXX)
+    add_subdirectory(test)
+  endif()
 endif()
 
 add_subdirectory(src)
@@ -298,6 +298,9 @@ class PyWriter(IndentWriterBase):
     _enum_member2num = None
 
+    def __init__(self, for_imperative=False):
+        self._imperative = for_imperative
+
     def __call__(self, fout, defs):
         super().__call__(fout)
         self._enum_member2num = []
@@ -339,19 +342,35 @@ class PyWriter(IndentWriterBase):
             '        return super()._missing_(value)\n'
             '\n'
         )
-        self._write(
-            'def _as_dtype_num(dtype):\n'
-            '    import megengine._internal.mgb as m\n'
-            '    return m._get_dtype_num(dtype)\n'
-            '\n'
-        )
-        self._write(
-            '''
-def _as_serialized_dtype(dtype):
-    import megengine._internal.mgb as m
-    return m._get_serialized_dtype(dtype)
-'''
-        )
+        if not self._imperative:
+            self._write(
+                'def _as_dtype_num(dtype):\n'
+                '    import megengine._internal.mgb as m\n'
+                '    return m._get_dtype_num(dtype)\n'
+                '\n'
+            )
+            self._write(
+                'def _as_serialized_dtype(dtype):\n'
+                '    import megengine._internal.mgb as m\n'
+                '    return m._get_serialized_dtype(dtype)\n'
+                '\n'
+            )
+        else:
+            self._write(
+                'def _as_dtype_num(dtype):\n'
+                '    import xxx._xxx.utils as m\n'
+                '    return m._get_dtype_num(dtype)\n'
+                '\n'
+            )
+            self._write(
+                'def _as_serialized_dtype(dtype):\n'
+                '    import xxx._xxx.utils as m\n'
+                '    return m._get_serialized_dtype(dtype)\n'
+                '\n'
+            )
         self._process(defs)
         self._write(
             '''
@@ -777,8 +796,12 @@ def main():
                         'cpp file')
     parser.add_argument('input')
     parser.add_argument('output')
+    parser.add_argument('--imperative', action='store_true',
+                        help='generate files for imperative')
     args = parser.parse_args()
+    for_imperative = args.imperative
 
     with open(args.input) as fin:
         inputs = fin.read()
         exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
@@ -787,7 +810,7 @@ def main():
     input_hash = input_hash.hexdigest()
 
     if args.type == 'py':
-        writer = PyWriter()
+        writer = PyWriter(for_imperative=for_imperative)
     else:
         assert args.type == 'c++'
         if args.enumv:
@@ -151,27 +151,31 @@ if(ANDROID)
     target_link_libraries(megbrain PUBLIC log)
 endif()
 
-# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
-add_library(megengine)
-target_link_libraries(megengine PUBLIC megbrain megdnn)
-if(UNIX AND NOT APPLE)
-    # TODO: Use target_link_options after upgrading to CMake 3.13
-    target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${PROJECT_SOURCE_DIR}/python_module/src/version.ld)
-endif()
-set_target_properties(megengine PROPERTIES CXX_VISIBILITY_PRESET default)
-set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
-if(MGE_WITH_DISTRIBUTED)
+if(NOT MGE_BUILD_XXX)
+    # Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
+    add_library(megengine)
+    target_link_libraries(megengine PUBLIC megbrain megdnn)
+    if(UNIX AND NOT APPLE)
+        # TODO: Use target_link_options after upgrading to CMake 3.13
+        # FIXME: use the right directory for mgb or imperative
+        target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${PROJECT_SOURCE_DIR}/python_module/src/version.ld)
+    endif()
+    set_target_properties(megengine PROPERTIES CXX_VISIBILITY_PRESET default)
+    set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
     # Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
     # for this.
     install(TARGETS megengine
         LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
         ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
-else()
-    install(TARGETS megengine megbrain
+endif()
+
+if(NOT MGE_WITH_DISTRIBUTED)
+    install(TARGETS megbrain
         EXPORT ${MGE_EXPORT_TARGETS}
         LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
         ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
 endif()
 
 foreach(_PATH ${MGB_INC})
     install(DIRECTORY ${_PATH}/megbrain DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.h")
 endforeach()
@@ -271,6 +271,23 @@ OperatorNodeBase* ComputingGraphImpl::insert_opr(
         std::unique_ptr<OperatorNodeBase> opr_uniqp) {
     auto opr = opr_uniqp.get();
 
+    if (options().imperative_proxy_graph) {
+        if (!opr->inserted_in_graph()) {
+            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
+            opr->set_inserted_in_graph();
+            opr->init_output_comp_node();
+            opr->init_output_dtype();
+            opr->init_output_format();
+            // register static infer
+            {
+                auto&& mgr = static_infer_manager_impl();
+                auto old = mgr.set_register_allowed_opr(opr);
+                opr->init_output_static_infer_desc();
+                mgr.set_register_allowed_opr(old);
+            }
+        }
+        return opr;
+    }
     if (opr->inserted_in_graph()) {
         // FIXME: it's just a trick used for re-evaluation in eager evaluation
         // mode. Since comp_graph has already taken an ownership of the opr,
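
For orientation, the `set_register_allowed_opr` dance in the hunk above is a plain save/set/restore pattern: permission to register static-infer descriptors is granted to exactly one operator while its `init_output_static_infer_desc()` runs, then handed back. A minimal standalone sketch under that reading (`InferRegistry` and `set_allowed` are illustrative names, not MegEngine APIs):

```cpp
#include <cassert>

// Registry that only accepts registrations from one currently-allowed
// producer, mirroring set_register_allowed_opr above.
struct InferRegistry {
    void* allowed = nullptr;
    // Swap in a new allowed producer; hand back the previous one so the
    // caller can restore it afterwards.
    void* set_allowed(void* p) { void* old = allowed; allowed = p; return old; }
    void register_desc(void* who) { assert(who == allowed); }
};

void insert_with_scoped_permission(InferRegistry& mgr, void* opr) {
    void* old = mgr.set_allowed(opr);  // temporarily allow this opr
    mgr.register_desc(opr);            // registration happens here
    mgr.set_allowed(old);              // restore the previous permission
}
```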
@@ -133,6 +133,15 @@ void cg::register_grad_func(Typeinfo *opr_type, OprGradFunc grad) {
             opr_type->name);
 }
 
+OprGradFunc* cg::lookup_grad_func(Typeinfo *opr_type) {
+    auto giter = static_data().grad_func_registry.find(opr_type);
+    if (giter != static_data().grad_func_registry.end()) {
+        return &giter->second;
+    } else {
+        return nullptr;
+    }
+}
+
 class GradManager::StreamStrongPropInfer {
     DepOprIter m_opr_iter;
     ThinHashSet<OperatorNodeBase*> m_strong_oprs;
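
The new `lookup_grad_func` returns a pointer rather than asserting, so callers (e.g. an imperative backward pass) can detect an unregistered operator type and branch. A self-contained sketch of the same find-or-null idiom, with hypothetical `GradFn`/`lookup` names standing in for `OprGradFunc` and the real registry:

```cpp
#include <functional>
#include <string>
#include <unordered_map>

using GradFn = std::function<int(int)>;

// Stand-in for grad_func_registry: keyed lookup yields either a pointer to
// the stored functor or nullptr, never throwing on a miss.
std::unordered_map<std::string, GradFn> registry;

GradFn* lookup(const std::string& type) {
    auto it = registry.find(type);
    return it != registry.end() ? &it->second : nullptr;
}

// Usage: probe before invoking, instead of asserting on a miss.
// if (GradFn* fn = lookup("Convolution")) { (*fn)(42); }
```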
@@ -101,6 +101,11 @@ OperatorNodeBase::~OperatorNodeBase() noexcept {
 }
 
 void OperatorNodeBase::execute(ExecEnv &env) {
+    if (owner_graph()->options().imperative_proxy_graph) {
+        do_execute(env);
+        return;
+    }
+
     owner_graph()->event().signal_inplace<event::OprExecStart>(this, &env);
 
     // dispatch waiting commands
@@ -230,6 +230,9 @@ VarNode& VarNode::format(TensorFormat format) {
 bool VarNode::set_fwd_in2out_readonly(
         VarNode *input, const SubTensorSpec &sub) {
+    if (owner_graph()->options().imperative_proxy_graph) {
+        return false;
+    }
     return static_cast<ComputingGraphImpl*>(owner_graph())
         ->var_node_mem_manager().fwd_in2out_readonly(input, sub, this);
 }
@@ -242,6 +245,7 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
 VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) {
+    mgb_assert(!owner_graph()->options().imperative_proxy_graph);
     static_cast<ComputingGraphImpl*>(owner_graph())
         ->var_node_mem_manager().fwd_in2out_writable_force(input, this);
     return *this;
@@ -440,6 +440,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             bool eager_evaluation = false;
 #endif
 
+            bool imperative_proxy_graph = false;
+
             //! add extra deps for the comp seq if a specific var is dependent
             ThinHashMap<VarNode*, VarNodeArray> extra_vardeps;
@@ -74,6 +74,11 @@ namespace cg {
 void register_grad_func(Typeinfo *opr_type, OprGradFunc grad);
 
 /*!
+ * \brief lookup grad func for an operator type
+ */
+OprGradFunc* lookup_grad_func(Typeinfo *opr_type);
+
+/*!
  * \brief add a callback to be invoked when grad of given var is computed
  *
  * All transformers would be chained in their added order, and the last
@@ -69,6 +69,10 @@ class OperatorNodeConfig final: public Hashable {
             return *this;
         }
 
+        const Maybe<std::string>& name() const {
+            return m_name;
+        }
+
         /*!
          * \brief update instance ID
         *
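
The new accessor hands out the `Maybe<std::string>` by const reference, so callers can test whether a name was set without copying the string. A minimal sketch of the same shape, assuming `std::optional` as a stand-in for `Maybe` (this `Config` class is hypothetical, not the MegEngine one):

```cpp
#include <optional>
#include <string>
#include <utility>

class Config {
    std::optional<std::string> m_name;
public:
    // Fluent setter, matching the config-builder style of the hunk above.
    Config& name(std::string s) { m_name = std::move(s); return *this; }
    // Const-reference getter: presence check and read without a copy.
    const std::optional<std::string>& name() const { return m_name; }
};

// Usage: if (cfg.name()) { /* use *cfg.name() */ }
```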
@@ -22,6 +22,10 @@
 #include <mutex>
 
 namespace mgb {
+namespace imperative {
+class ProxyGraph;
+}  // namespace imperative
+
 namespace cg {
 
 namespace static_infer {
     class StaticInferManagerImpl;
@@ -576,6 +580,7 @@ class VarNode final: public GraphNodeBase {
     friend class VarDevMemDefragmenter;
     friend class EagerEvalManager;
     friend class MemAllocPlan;
+    friend class imperative::ProxyGraph;
 };
 
 enum class VarNode::Flag: uint32_t {
@@ -29,6 +29,8 @@
 #cmakedefine01 MGB_ENABLE_FBS_SERIALIZATION
 #cmakedefine01 MGB_IS_DEV
+#cmakedefine01 MGB_ENABLE_IMPERATIVE
+
 // DNN related flags
 // Platform macro's
 #cmakedefine01 MEGDNN_WITH_CUDA
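
For reference, `#cmakedefine01` is standard `configure_file()` behavior: the generated header always defines the macro, as `0` or `1` depending on the CMake variable, so it is safe to test with `#if`. Together with the `set(MGB_ENABLE_IMPERATIVE ...)` line earlier, the build wiring produces something like:

```cpp
// Generated megbrain_build_config.h (value follows MGE_BUILD_XXX):
//   #define MGB_ENABLE_IMPERATIVE 1

// ...which is what makes guards like the one in scn_do_execute work:
#if MGB_ENABLE_IMPERATIVE
// imperative-only code path
#endif
```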
@@ -40,29 +40,37 @@ BatchNormForward::BatchNormForward(VarNode *x,
         Super{x->owner_graph(), config, "batch_norm",
               {x, scale, bias, mean, variance}}
 {
-    auto check_dest = [&](VarNode* dest) {
-        auto dest_opr = dest->owner_opr();
-        mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
-                       dest_opr->same_type<VolatileSharedDeviceTensor>()),
-                     GraphError,
-                     "mean&variance in BatchNorm must be SharedDeviceTensor/VolatileSharedDeviceTensor; "
-                     "got %s{%s} actually",
-                     dest_opr->cname(), dest_opr->dyn_typeinfo()->name);
-    };
-    check_dest(mean);
-    check_dest(variance);
+    if (owner_graph()->options().imperative_proxy_graph) {
+        m_force_inplace = false;
+    }
+
+    if (m_force_inplace) {
+        auto check_dest = [&](VarNode* dest) {
+            auto dest_opr = dest->owner_opr();
+            mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
+                           dest_opr->same_type<VolatileSharedDeviceTensor>()),
+                         GraphError,
+                         "mean and variance in BatchNorm must be SharedDeviceTensor "
+                         "or VolatileSharedDeviceTensor; got %s{%s} actually",
+                         dest_opr->cname(), dest_opr->dyn_typeinfo()->name);
+        };
+        check_dest(mean);
+        check_dest(variance);
+    }
 
     init_megdnn_opr(*this, param);
     add_input({x, scale, bias, mean, variance});
 
-    output(0)->
-            set_fwd_in2out_writable_force(input(3)).
-            add_flag(VarNode::Flag::NO_MEM_RECLAIM);
-    output(1)->
-            set_fwd_in2out_writable_force(input(4)).
-            add_flag(VarNode::Flag::NO_MEM_RECLAIM);
+    if (m_force_inplace) {
+        output(0)->
+                set_fwd_in2out_writable_force(input(3)).
+                add_flag(VarNode::Flag::NO_MEM_RECLAIM);
+        output(1)->
+                set_fwd_in2out_writable_force(input(4)).
+                add_flag(VarNode::Flag::NO_MEM_RECLAIM);
+    }
 }
 
 BatchNormForward::BatchNormForward(VarNode *x,
@@ -129,17 +137,40 @@ BatchNormForward::do_make_node_prop() const {
 
 void BatchNormForward::scn_do_execute() {
     auto &&x = input(0)->dev_tensor();
-    auto &&y = output(4)->dev_tensor();
-    mgb_assert(x.layout().is_contiguous() &&
-               y.layout().is_contiguous());
+#if MGB_ENABLE_IMPERATIVE
+    if (input().size() == 5) {  // need running mean/variance
+        auto &&o0 = output(0)->dev_tensor(),
+             &&o1 = output(1)->dev_tensor(),
+             &&i0 = input(3)->dev_tensor(),
+             &&i1 = input(4)->dev_tensor();
+        mgb_assert(o0.raw_ptr() && o1.raw_ptr());  // non-empty tensor
+        mgb_assert(o0.comp_node() == i0.comp_node() &&
+                   o1.comp_node() == i1.comp_node() &&
+                   o0.layout().eq_layout(i0.layout()) &&
+                   o1.layout().eq_layout(i1.layout()));
+        if (!m_force_inplace) {
+            if (o0.raw_ptr() != i0.raw_ptr()) {
+                o0.copy_from_fixlayout(i0);
+            }
+            if (o1.raw_ptr() != i1.raw_ptr()) {
+                o1.copy_from_fixlayout(i1);
+            }
+        } else {
+            mgb_assert(o0.raw_ptr() == i0.raw_ptr()
+                       && o1.raw_ptr() == i1.raw_ptr());
+        }
+    }
+#endif
     auto scale = input(1)->dev_tensor().as_megdnn();
     auto bias = input(2)->dev_tensor().as_megdnn();
     auto mean = output(0)->dev_tensor().as_megdnn();
     auto variance = output(1)->dev_tensor().as_megdnn();
     auto save_mean = output(2)->dev_tensor().as_megdnn();
     auto save_variance = output(3)->dev_tensor().as_megdnn();
-    auto workspace = intl::get_megdnn_workspace_from_var(
-            output().back());
+    auto &&y = output(4)->dev_tensor();
+    mgb_assert(x.layout().is_contiguous() &&
+               y.layout().is_contiguous());
+    auto workspace = intl::get_megdnn_workspace_from_var(output().back());
     megdnn_opr()->exec(x.as_megdnn(), scale, bias, mean, variance,
                        save_mean, save_variance, y.as_megdnn(), workspace);
 }
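
The non-inplace branch above copies the running statistics only when the output buffer does not already alias the input. A self-contained sketch of that copy-unless-aliased idiom (`update_stats` is a hypothetical stand-in, not the operator's code):

```cpp
#include <algorithm>
#include <cstddef>

// In non-inplace mode the input stats are copied into fresh outputs; in
// forced-inplace mode both names refer to the same storage already, so the
// copy would be a no-op and is skipped.
void update_stats(float* out, const float* in, std::size_t n) {
    if (out != in) {                 // aliasing means nothing to do
        std::copy(in, in + n, out);  // otherwise forward the running stats
    }
}
```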
@@ -191,6 +222,14 @@ void BatchNormForward::init_output_dtype() {
     }
 }
 
+void BatchNormForward::mem_plan_fwd_in2out_writable() {
+    if (!m_force_inplace && input().size() == 5) {
+        // TODO: testing
+        output(0)->set_fwd_in2out_writable(input(3));
+        output(1)->set_fwd_in2out_writable(input(4));
+    }
+}
+
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
     mgb_assert(wrt_idx < 5);
     if (wrt_idx < 3) {
@@ -271,17 +271,26 @@ WorkspaceLimitGetter::get_impl(ComputingGraph *graph) {
 size_t WorkspaceLimitGetter::get_workspace_limit(
         ComputingGraph *graph, CompNode cn, size_t old_limit) {
+    if (graph->options().imperative_proxy_graph) {
+        return old_limit;
+    }
     if (!graph->options().seq_opt.enable_mem_reuse_alloc)
         return old_limit;
     return get_impl(graph)->get_workspace_limit(cn, old_limit);
 }
 
 bool WorkspaceLimitGetter::is_prealloc_run(ComputingGraph* graph) {
+    if (graph->options().imperative_proxy_graph) {
+        return false;
+    }
     return graph->options().seq_opt.enable_mem_reuse_alloc &&
            get_impl(graph)->is_prealloc_run();
 }
 
 VarNode* WorkspaceLimitGetter::register_to_graph(ComputingGraph *graph) {
+    if (graph->options().imperative_proxy_graph) {
+        return nullptr;
+    }
     auto maker = [graph](){
         return std::make_shared<Impl>(graph);
     };
@@ -75,6 +75,10 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
             const TensorShapeArray &output_shapes) const override;
     void init_output_static_infer_desc() override;
     void init_output_dtype() override;
+    void mem_plan_fwd_in2out_writable() override;
+
+    // if set to true, running mean/variance will be updated in place
+    bool m_force_inplace = true;
 };
 
 using BatchNorm = BatchNormForward;
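
As far as this patch shows, the design is: `m_force_inplace` defaults to `true`, preserving the old graph-mode behavior where outputs 0 and 1 are forced to alias the running mean/variance inputs (with `NO_MEM_RECLAIM`); when `imperative_proxy_graph` is set, the constructor flips it to `false`, and the new `mem_plan_fwd_in2out_writable` plus the copy branch in `scn_do_execute` fall back to optional memory forwarding with explicit copies.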