Browse Source

feat(dnn/mge): add compnode multithread in python

GitOrigin-RevId: 47373d291d
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
8cf7150df0
6 changed files with 73 additions and 46 deletions
  1. +4
    -1
      python_module/megengine/core/device.py
  2. +6
    -5
      sdk/load-and-run/src/mgblar.cpp
  3. +27
    -7
      src/core/impl/comp_node/comp_node.cpp
  4. +5
    -7
      src/core/impl/comp_node/cpu/comp_node.cpp
  5. +7
    -3
      src/core/include/megbrain/comp_node.h
  6. +24
    -23
      src/core/test/comp_node.cpp

+ 4
- 1
python_module/megengine/core/device.py View File

@@ -38,9 +38,12 @@ def set_default_device(device: str = "xpux"):


:param device: default device type. The type can be 'cpu0', 'cpu1', etc., :param device: default device type. The type can be 'cpu0', 'cpu1', etc.,
or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use. or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use.
To specify multiple devices, use cpu0:1 or gpu0:2.
'cpux' and 'gupx' can also be used to specify any number of cpu or gpu devices. 'cpux' and 'gupx' can also be used to specify any number of cpu or gpu devices.


'multithread' device type is avaliable when inference, which implements
multi-threading parallelism at the operator level. For example,
'multithread4' will compute with 4 threads. which implements

The default value is 'xpux' to specify any device available. The default value is 'xpux' to specify any device available.


It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.


+ 6
- 5
sdk/load-and-run/src/mgblar.cpp View File

@@ -603,11 +603,11 @@ Args Args::from_argv(int argc, char **argv) {
++ i; ++ i;
ret.multithread_number = std::stoi(argv[i]); ret.multithread_number = std::stoi(argv[i]);
ret.load_config.comp_node_mapper = ret.load_config.comp_node_mapper =
[nr_thread =
[nr_threads =
ret.multithread_number](CompNode::Locator& loc) { ret.multithread_number](CompNode::Locator& loc) {
loc.type = CompNode::DeviceType::MULTITHREAD; loc.type = CompNode::DeviceType::MULTITHREAD;
loc.device = 0; loc.device = 0;
loc.stream = nr_thread;
loc.nr_threads = nr_threads;
}; };
continue; continue;
} }
@@ -615,11 +615,12 @@ Args Args::from_argv(int argc, char **argv) {
mgb_log_warn("use multithread:default mode"); mgb_log_warn("use multithread:default mode");
++i; ++i;
ret.multithread_number = std::stoi(argv[i]); ret.multithread_number = std::stoi(argv[i]);
ret.load_config.comp_node_mapper = [nr_thread =
ret.multithread_number](CompNode::Locator& loc) {
ret.load_config.comp_node_mapper = [nr_threads =
ret.multithread_number](
CompNode::Locator& loc) {
loc.type = CompNode::DeviceType::MULTITHREAD; loc.type = CompNode::DeviceType::MULTITHREAD;
loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
loc.stream = nr_thread;
loc.nr_threads = nr_threads;
}; };
continue; continue;
} }


+ 27
- 7
src/core/impl/comp_node/comp_node.cpp View File

@@ -127,13 +127,19 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
// current parsing location // current parsing location
const char *ptr = id.data(); const char *ptr = id.data();
if (id == "cpu:default") { if (id == "cpu:default") {
return {DeviceType::CPU, DEVICE_CPU_DEFAULT, 0};
return {DeviceType::CPU, DEVICE_CPU_DEFAULT, {0}};
} }
if (!strncmp(ptr, "multithread:default", 19)) { if (!strncmp(ptr, "multithread:default", 19)) {
//! the multithread default compnode string like "multithread:default:x" //! the multithread default compnode string like "multithread:default:x"
ptr += 20;
int nr_thread =std::stoi(ptr);
return {DeviceType::MULTITHREAD, DEVICE_MULTITHREAD_DEFAULT, nr_thread};
if (id.size() > 20) {
ptr += 20;
int nr_thread = std::stoi(ptr);
return {DeviceType::MULTITHREAD,
DEVICE_MULTITHREAD_DEFAULT,
{nr_thread}};
} else {
err();
}
} }


DeviceType dev_type; DeviceType dev_type;
@@ -192,8 +198,16 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
int num_stream = parse_int(); int num_stream = parse_int();
if (*ptr) if (*ptr)
err(); err();
//! multi thread with thread number(num_stream) being zero is illegal
if (dev_type == DeviceType::MULTITHREAD) {
if (num_dev == 0) {
err();
}
//! num_steam store the nr_thread
std::swap(num_dev, num_stream);
}


return {dev_type, num_dev, num_stream};
return {dev_type, num_dev, {num_stream}};
} }


void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { void CompNode::Locator::set_device_map(DeviceType type, int from, int to) {
@@ -242,16 +256,22 @@ CompNode::Locator CompNode::Locator::to_physical() const {
stream_physical = 1023; stream_physical = 1023;
} }
} }
return {type_physical, device_physical, stream_physical};
return {type_physical, device_physical, {stream_physical}};
} }


std::string CompNode::Locator::to_string() const { std::string CompNode::Locator::to_string() const {
if (device == DEVICE_CPU_DEFAULT) { if (device == DEVICE_CPU_DEFAULT) {
return "cpu:default"; return "cpu:default";
} else if (device == DEVICE_MULTITHREAD_DEFAULT) { } else if (device == DEVICE_MULTITHREAD_DEFAULT) {
std::string ret="multithread:default:";
std::string ret = "multithread:default:";
ret.append(get_stream_str(stream)); ret.append(get_stream_str(stream));
return ret; return ret;
} else if (type == DeviceType::MULTITHREAD) {
std::string ret("multithread");
ret.append(get_stream_str(stream))
.append(":")
.append(get_stream_str(device));
return ret;
} }
char numstr[32]; char numstr[32];
if (device == -1) { if (device == -1) {


+ 5
- 7
src/core/impl/comp_node/cpu/comp_node.cpp View File

@@ -380,9 +380,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
m_locator_logical(locator_logical) { m_locator_logical(locator_logical) {
auto cn = make_comp_node_from_impl(this); auto cn = make_comp_node_from_impl(this);
if (locator.type == DeviceType::MULTITHREAD) { if (locator.type == DeviceType::MULTITHREAD) {
//! When multi-thread the stream stand for thread number
m_thread_pool = std::unique_ptr<ThreadPool>(
new ThreadPool(static_cast<size_t>(locator.stream)));
m_thread_pool = std::unique_ptr<ThreadPool>(new ThreadPool(
static_cast<size_t>(locator.nr_threads)));
mgb_assert(m_thread_pool, "ThradPool create failed");
} }


if (locator.type == DeviceType::CPU) { if (locator.type == DeviceType::CPU) {
@@ -398,7 +398,6 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
cn); cn);
} }
} else if (locator.type == DeviceType::MULTITHREAD) { } else if (locator.type == DeviceType::MULTITHREAD) {
mgb_assert(m_thread_pool, "ThradPool create failed");
if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
m_env.init_cpu( m_env.init_cpu(
{std::make_shared<InplaceCPUDispatcher>( {std::make_shared<InplaceCPUDispatcher>(
@@ -745,15 +744,14 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
} else { } else {
mgb_assert(locator.type == DeviceType::MULTITHREAD); mgb_assert(locator.type == DeviceType::MULTITHREAD);
auto&& pqueue_weak = sm_pool->physical2queue_multithead[{ auto&& pqueue_weak = sm_pool->physical2queue_multithead[{
locator.device, locator.stream}];
locator.device, locator.nr_threads}];
auto pqueue = pqueue_weak.lock(); auto pqueue = pqueue_weak.lock();
if (!pqueue) { if (!pqueue) {
pqueue = std::make_shared<WorkerQueue>(locator); pqueue = std::make_shared<WorkerQueue>(locator);
pqueue_weak = pqueue; pqueue_weak = pqueue;
} }
auto&& pimpl = sm_pool->logical2impl_multi_thread[{ auto&& pimpl = sm_pool->logical2impl_multi_thread[{
static_cast<int>(compact_logical_device),
locator_logical.stream}];
compact_logical_device, locator_logical.nr_threads}];
if (!pimpl) { if (!pimpl) {
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
"too many cpu multithread comp nodes; max %d allowed", "too many cpu multithread comp nodes; max %d allowed",


+ 7
- 3
src/core/include/megbrain/comp_node.h View File

@@ -153,8 +153,12 @@ class CompNode {
int device = -1; int device = -1;


//! multiple streams can execute on one computing device and share //! multiple streams can execute on one computing device and share
//! memory
int stream = 0;
//! memory, when compnode type is multithread the field also stand
//! for nr_threads
union {
int stream = 0;
int nr_threads;
};


/*! /*!
* \brief parse a string identifier * \brief parse a string identifier
@@ -162,7 +166,7 @@ class CompNode {
* currently supported ID format: (gpu|cpu)<n>[:m] where n is the * currently supported ID format: (gpu|cpu)<n>[:m] where n is the
* device number, possibly with m as the stream id. * device number, possibly with m as the stream id.
*/ */
static Locator parse(const std::string &id);
static Locator parse(const std::string& id);


/*! /*!
* \brief set mapping between device numbers of a device type * \brief set mapping between device numbers of a device type


+ 24
- 23
src/core/test/comp_node.cpp View File

@@ -28,9 +28,7 @@ using namespace mgb;
TEST(TestCompNode, Parse) { TEST(TestCompNode, Parse) {
using L = CompNode::Locator; using L = CompNode::Locator;
using D = CompNode::DeviceType; using D = CompNode::DeviceType;
auto make_lc = [](D t, int dev, int s) -> L {
return {t, dev, s};
};
auto make_lc = [](D t, int dev, int s) -> L { return {t, dev, {s}}; };


ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0)); ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0));
ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23)); ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23));
@@ -47,10 +45,9 @@ TEST(TestCompNode, Parse) {
ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0)); ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0));
ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1)); ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1));


ASSERT_EQ(L::parse("cpu:default"),
make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0));
ASSERT_EQ(L::parse("multithread0:2"), make_lc(D::MULTITHREAD, 0, 2));
ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 1, 3));
ASSERT_EQ(L::parse("cpu:default"), make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0));
ASSERT_EQ(L::parse("multithread2:0"), make_lc(D::MULTITHREAD, 0, 2));
ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 3, 1));
ASSERT_EQ(L::parse("multithread:default:2"), ASSERT_EQ(L::parse("multithread:default:2"),
make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2)); make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2));


@@ -65,6 +62,10 @@ TEST(TestCompNode, Parse) {
ASSERT_THROW(L::parse("heaxgon0"), MegBrainError); ASSERT_THROW(L::parse("heaxgon0"), MegBrainError);
ASSERT_THROW(L::parse("rcom0"), MegBrainError); ASSERT_THROW(L::parse("rcom0"), MegBrainError);
ASSERT_THROW(L::parse("cmabricon0"), MegBrainError); ASSERT_THROW(L::parse("cmabricon0"), MegBrainError);
ASSERT_THROW(L::parse("multithread"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:default"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:default:0"), MegBrainError);
} }


TEST(TestCompNode, SetDefaultDev) { TEST(TestCompNode, SetDefaultDev) {
@@ -107,12 +108,12 @@ TEST(TestCompNode, Load) {
#endif #endif


#if MGB_HAVE_THREAD #if MGB_HAVE_THREAD
auto cn_multi_thread0 = CompNode::load("multithread0:2");
auto cn_multi_thread1 = CompNode::load("multithread1:2");
ASSERT_EQ(CompNode::load("multithread0:2"), cn_multi_thread0);
ASSERT_EQ(CompNode::load("multithread1:2"), cn_multi_thread1);
ASSERT_NE(CompNode::load("multithread0:4"), cn_multi_thread0);
ASSERT_NE(CompNode::load("multithread1:4"), cn_multi_thread1);
auto cn_multi_thread0 = CompNode::load("multithread2:0");
auto cn_multi_thread1 = CompNode::load("multithread2:1");
ASSERT_EQ(CompNode::load("multithread2:0"), cn_multi_thread0);
ASSERT_EQ(CompNode::load("multithread2:1"), cn_multi_thread1);
ASSERT_NE(CompNode::load("multithread4:0"), cn_multi_thread0);
ASSERT_NE(CompNode::load("multithread4:1"), cn_multi_thread1);


auto cn_multi_default0 = CompNode::load("multithread:default:2"); auto cn_multi_default0 = CompNode::load("multithread:default:2");
auto cn_multi_default1 = CompNode::load("multithread:default:4"); auto cn_multi_default1 = CompNode::load("multithread:default:4");
@@ -139,7 +140,7 @@ TEST(TestCompNode, FreeAfterFinalize) {
auto type = static_cast<CompNode::DeviceType>(i); auto type = static_cast<CompNode::DeviceType>(i);
if (!CompNode::get_device_count(type)) if (!CompNode::get_device_count(type))
continue; continue;
auto cn = CompNode::load(CompNode::Locator{type});
auto cn = CompNode::load(CompNode::Locator{type, -1, {0}});
auto ptr = cn.alloc_device(123); auto ptr = cn.alloc_device(123);
CompNode::finalize(); CompNode::finalize();
cn.free_device(ptr); cn.free_device(ptr);
@@ -190,13 +191,13 @@ TEST(TestCompNodeCPU, CoreAffinity) {
size_t data0, data1 = 0; size_t data0, data1 = 0;
auto empty_task = []() {}; auto empty_task = []() {};
auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"), auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"),
cn2 = CompNode::load("multithread0:2");
cn2 = CompNode::load("multithread2:0");
auto binding0 = [&](size_t) { data0 = 10; }; auto binding0 = [&](size_t) { data0 = 10; };
CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0); CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0);
CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task); CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task);
cn0.sync(); cn0.sync();


auto binding1 = [&](size_t) { data1 = 20; };
auto binding1 = [&](size_t ) { data1 = 20; };
CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1); CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1);
CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task); CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task);
cn1.sync(); cn1.sync();
@@ -238,7 +239,7 @@ TEST(TestCompNode, CPU_MULTI_THREAD) {
}; };


for (auto&& str : std::vector<std::string>{ for (auto&& str : std::vector<std::string>{
"multithread0:2", "multithread0:4", "multithread:default:4"}) {
"multithread2:0", "multithread4:0", "multithread:default:4"}) {
auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str); auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str);
std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)}; std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)};
std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)}; std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)};
@@ -271,9 +272,9 @@ TEST(TestCompNodeCPU, PhysicalDispatch) {
L::set_device_map(DT, ID, 0); L::set_device_map(DT, ID, 0);
L::set_device_map(DT, ID + 1, 0); L::set_device_map(DT, ID + 1, 0);
L::set_device_map(DT, ID + 2, 1); L::set_device_map(DT, ID + 2, 1);
auto cn0 = CompNode::load({DT, ID, 0}),
cn1 = CompNode::load({DT, ID + 1, 0}),
cn2 = CompNode::load({DT, ID + 2, 0});
auto cn0 = CompNode::load({DT, ID, {0}}),
cn1 = CompNode::load({DT, ID + 1, {0}}),
cn2 = CompNode::load({DT, ID + 2, {0}});
#if MGB_HAVE_THREAD #if MGB_HAVE_THREAD
ASSERT_NE(cn0, cn1); ASSERT_NE(cn0, cn1);
#else #else
@@ -532,10 +533,10 @@ TEST(TestCompNode, MultipleLoad) {
for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) {
auto dt = static_cast<CompNode::DeviceType>(i); auto dt = static_cast<CompNode::DeviceType>(i);
if (CompNode::get_device_count(dt)) { if (CompNode::get_device_count(dt)) {
auto cn = CompNode::load({dt});
auto cn = CompNode::load({dt, 0, {0}});
mgb_log("comp node %s is available", cn.to_string().c_str()); mgb_log("comp node %s is available", cn.to_string().c_str());
run(cn); run(cn);
cn = CompNode::load({dt});
cn = CompNode::load({dt, 0, {0}});
run(cn); run(cn);
} }
} }
@@ -591,7 +592,7 @@ TYPED_TEST(TestCPUCompSeqRec, run_default_cpu) {
comp_node_test::seq_rec::run<TypeParam>(CompNode::load("cpu:default")); comp_node_test::seq_rec::run<TypeParam>(CompNode::load("cpu:default"));
} }
TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) { TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) {
auto cn = CompNode::load("multithread0:4");
auto cn = CompNode::load("multithread4:0");
comp_node_test::seq_rec::run<TypeParam>(cn); comp_node_test::seq_rec::run<TypeParam>(cn);
} }




Loading…
Cancel
Save