
fix(param_pack): impl param pack concat in imperative_rt

GitOrigin-RevId: 91edd9c0bf
Branch: release-1.5
Author: Megvii Engine Team (4 years ago)
Commit: 346d242051

3 changed files with 138 additions and 64 deletions:
  1. imperative/src/impl/async_releaser.h       +87  -0
  2. imperative/src/impl/ops/tensor_manip.cpp   +48  -0
  3. imperative/src/impl/physical_tensor.cpp    +3   -64
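For context on the commit message: the new apply_on_physical_tensor entry in tensor_manip.cpp lets the imperative runtime execute ParamPackConcat directly on physical tensors instead of going through the generic fallback path. A rough dispatch sketch follows; it assumes the autogen op can be constructed directly and that OpDef::apply_on_physical_tensor is the eager dispatch entry point, both of which are outside this diff and only illustrative.

// Rough sketch, not part of the diff; construction details are assumptions.
using namespace mgb::imperative;

SmallVector<TensorPtr> concat_params(SmallVector<TensorPtr> tensors,
                                     TensorPtr offsets) {
    auto op = std::make_shared<ParamPackConcat>();  // offsets attribute left default
    tensors.push_back(std::move(offsets));          // last input is the offsets tensor
    // Routes to param_pack_concat_apply_on_physical_tensor registered below.
    return OpDef::apply_on_physical_tensor(*op, tensors);
}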

imperative/src/impl/async_releaser.h  (+87, -0)

@@ -0,0 +1,87 @@
/**
 * \file imperative/src/impl/async_releaser.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/system.h"

#include "./event_pool.h"

namespace mgb {
namespace imperative {

class AsyncReleaser : public CompNodeDepedentObject {
    struct WaiterParam {
        CompNode cn;
        CompNode::Event* event;
        BlobPtr blob;
        HostTensorStorage::RawStorage storage;
    };
    class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
        AsyncReleaser* m_par_releaser;

    public:
        // disable busy wait by setting max_spin=0 to save CPU cycles
        Waiter(AsyncReleaser* releaser)
                : AsyncQueueSC<WaiterParam, Waiter>(0),
                  m_par_releaser(releaser) {}

        void process_one_task(WaiterParam& param) {
            if (param.event->finished()) {
                param.blob.reset();
                param.storage.reset();
                EventPool::without_timer().free(param.event);
                return;
            }

            using namespace std::literals;
            std::this_thread::sleep_for(1us);
            add_task(std::move(param));
        }
        void on_async_queue_worker_thread_start() override {
            sys::set_thread_name("releaser");
        }
    };
    Waiter m_waiter{this};

protected:
    std::shared_ptr<void> on_comp_node_finalize() override {
        m_waiter.wait_task_queue_empty();
        return {};
    }

public:
    static AsyncReleaser* inst() {
        static AsyncReleaser releaser;
        return &releaser;
    }

    ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); }

    void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }

    void add(const HostTensorND& hv) {
        add(hv.comp_node(), {}, hv.storage().raw_storage());
    }

    void add(CompNode cn, BlobPtr blob,
             HostTensorStorage::RawStorage storage = {}) {
        auto event = EventPool::without_timer().alloc(cn);
        event->record();
        m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
    }
};

}  // namespace imperative
}  // namespace mgb
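For reference, the releaser defers resource destruction until pending device work has drained: add() records an event on the given comp node, and the "releaser" worker thread drops the blob or host storage only once that event reports completion. A minimal usage sketch (the surrounding helper is illustrative, not part of this commit):

// Illustrative helper: hand a device blob that an in-flight kernel may still
// read to the releaser instead of freeing it immediately.
void release_when_safe(mgb::imperative::BlobPtr blob, mgb::CompNode cn) {
    // Records an event on cn now; the blob reference is dropped on the
    // releaser thread once the event has finished.
    mgb::imperative::AsyncReleaser::inst()->add(std::move(blob), cn);
}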

imperative/src/impl/ops/tensor_manip.cpp  (+48, -0)

@@ -12,6 +12,9 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/tensor_manip.h"

#include "../async_releaser.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative {
@@ -173,6 +176,7 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
    auto&& shapes = get_shapes(param.shapes);
    size_t dtype_size = inputs[0]->layout().dtype.size();
    for (size_t i = 0; i < shapes.size(); ++i) {
        // memory forward
        ret.push_back(
                inputs[0]->sub(param.offsets[i * 2] * dtype_size, shapes[i]));
    }
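The offsets attribute holds an element-range pair per tensor, so param.offsets[i * 2] is where the i-th slice starts in the packed buffer, and the sub() call above builds a zero-copy view (hence the "memory forward" note). A small worked illustration, assuming two tensors of shapes {3} and {2, 2} packed back to back with no alignment padding (real offsets may include padding between tensors):

// Hypothetical offsets for outputs of shape {3} and {2, 2}:
//   param.offsets = {0, 3,   // tensor 0 occupies elements [0, 3)
//                    3, 7};  // tensor 1 occupies elements [3, 7)
// The loop above then yields views into the packed input:
//   ret[0] = inputs[0]->sub(0 * dtype_size, TensorShape{3});
//   ret[1] = inputs[0]->sub(3 * dtype_size, TensorShape{2, 2});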
@@ -197,8 +201,52 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
    return opr;
}

SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    def.cast_final_safe<ParamPackConcat>();
    // the last input is the int32 offsets tensor; the rest are the tensors to pack
    mgb_assert(inputs.size() > 1,
               "param_pack_concat requires at least one input tensor besides the offsets");
    auto comp_node = inputs.front()->comp_node();
    auto dtype = inputs.front()->dtype();
    size_t nr_inputs = inputs.size() - 1;
    size_t nr_elems = 0;
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto& input = inputs[i];
        mgb_assert(comp_node == input->comp_node(),
                   "inputs for param_pack_concat must be on the same comp_node");
        mgb_assert(dtype == input->dtype(),
                   "inputs for param_pack_concat must have the same dtype");
        nr_elems += input->layout().total_nr_elems();
    }
    // the packed output is a flat 1-D tensor holding all input elements
    auto dest_layout = TensorLayout({nr_elems}, dtype);
    auto output = Tensor::make(dest_layout, comp_node);
    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
    // the megdnn kernel reads the source pointers from a host-side table
    size_t srcs_size = sizeof(void*) * nr_inputs;
    void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
    std::shared_ptr<dt_byte> srcs_ptr = {
            (dt_byte*)srcs_raw_ptr,
            [comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
    TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
    size_t ws_size;
    {
        TensorShapeArray src_shapes;
        for (size_t i = 0; i < nr_inputs; ++i) {
            src_shapes.push_back(inputs[i]->shape());
        }
        ws_size = caller.op->get_workspace_in_bytes(
                src_shapes, inputs.back()->shape(), TensorShape{});
    }
    for (size_t i = 0; i < nr_inputs; ++i) {
        srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr;
    }
    HostTensorStorage srcs_storage;
    srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
    caller.op->exec({srcs_raw_ptr, srcs_layout},
                    inputs.back()->dev_tensor().as_megdnn(),
                    output->dev_tensor().as_megdnn(),
                    caller.create_workspace({{ws_size}, dtype::Byte()}));
    // keep the host pointer table alive until the async kernel has consumed it
    AsyncReleaser::inst()->add(
            HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
    return {output};
}

OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
        .apply_on_var_node(param_pack_concat_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
        .fallback();
}  // param_pack
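One subtlety in the function above: the megdnn ParamPackConcat kernel reads the table of source pointers from host memory while it runs asynchronously on the compute stream, so that table must not be freed as soon as exec() returns. The sketch below isolates that keep-alive pattern; names and sizes are illustrative, and the real code wraps the table in a HostTensorStorage exactly as shown in the diff.

// Illustrative keep-alive pattern, assuming `cn` is a CompNode and `layout`
// describes a host-side buffer handed to an async kernel.
void* raw = cn.alloc_host(layout.span().dist_byte());
std::shared_ptr<mgb::dt_byte> holder(
        static_cast<mgb::dt_byte*>(raw),
        [cn](mgb::dt_byte* p) { cn.free_host(p); });
// ... fill the buffer and launch the kernel that reads it asynchronously ...
mgb::HostTensorStorage storage;
storage.reset(cn, layout.span().dist_byte(), holder);
// AsyncReleaser records an event on cn and releases `raw` only after the
// stream has passed that event.
mgb::imperative::AsyncReleaser::inst()->add(
        mgb::HostTensorND{cn, layout}.storage(storage));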




imperative/src/impl/physical_tensor.cpp  (+3, -64)

@@ -11,7 +11,10 @@
#include "megbrain/imperative.h"
#include "megbrain/imperative/blob_manager.h"

#include "./event_pool.h"
#include "./async_releaser.h"

#include <mutex>

namespace mgb {
@@ -19,70 +22,6 @@ namespace imperative {

namespace {


class AsyncReleaser : public CompNodeDepedentObject {
    struct WaiterParam {
        CompNode cn;
        CompNode::Event* event;
        BlobPtr blob;
        HostTensorStorage::RawStorage storage;
    };
    class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
        AsyncReleaser* m_par_releaser;

    public:
        // disable busy wait by set max_spin=0 to save CPU cycle
        Waiter(AsyncReleaser* releaser)
                : AsyncQueueSC<WaiterParam, Waiter>(0),
                  m_par_releaser(releaser) {}

        void process_one_task(WaiterParam& param) {
            if (param.event->finished()) {
                param.blob.reset();
                param.storage.reset();
                EventPool::without_timer().free(param.event);
                return;
            }

            using namespace std::literals;
            std::this_thread::sleep_for(1us);
            add_task(std::move(param));
        }
        void on_async_queue_worker_thread_start() override {
            sys::set_thread_name("releaser");
        }
    };
    Waiter m_waiter{this};

protected:
    std::shared_ptr<void> on_comp_node_finalize() override {
        m_waiter.wait_task_queue_empty();
        return {};
    }

public:
    static AsyncReleaser* inst() {
        static AsyncReleaser releaser;
        return &releaser;
    }

    ~AsyncReleaser() {
        m_waiter.wait_task_queue_empty();
    }

    void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }

    void add(const HostTensorND& hv) {
        add(hv.comp_node(), {}, hv.storage().raw_storage());
    }

    void add(CompNode cn, BlobPtr blob,
             HostTensorStorage::RawStorage storage = {}) {
        auto event = EventPool::without_timer().alloc(cn);
        event->record();
        m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
    }
};

class CompNodeSyncManager : public CompNodeDepedentObject {
    ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event;
    std::mutex m_mtx;

