@@ -0,0 +1,33 @@ | |||||
from ..core._imperative_rt import TensorSanityCheckImpl | |||||
from ..core._imperative_rt.imperative import sync | |||||
class TensorSanityCheck:
    r"""An object that checks whether the input tensors of each operator have changed before and after the operation.

    It is used as a context manager: checking is enabled on entry and
    disabled again on exit.

    Examples:

    .. testcode::

        from megengine import tensor
        from megengine.utils.tensor_sanity_check import TensorSanityCheck
        with TensorSanityCheck() as checker:
            a = tensor([1, 2])
            b = tensor([3, 4])
            c = a + b
            print(c.numpy())

    .. testoutput::

        [4 6]
    """

    def __init__(self):
        # Native object exposing enable() / disable().
        self.impl = TensorSanityCheckImpl()

    def __enter__(self):
        # sync() is called before the checker is switched on.
        sync()
        self.impl.enable()
        return self

    def __exit__(self, val, type, trace):
        # The context-manager protocol passes these three arguments
        # positionally; they are not inspected here, and because nothing is
        # returned an exception raised inside the ``with`` body propagates.
        try:
            sync()
        finally:
            # Always switch the checker off, even when sync() raises, so the
            # checker is not left enabled after the ``with`` block.
            self.impl.disable()
@@ -23,6 +23,7 @@ | |||||
#include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
#include "megbrain/imperative/blob_manager.h" | #include "megbrain/imperative/blob_manager.h" | ||||
#include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
#include "megbrain/imperative/tensor_sanity_check.h" | |||||
#include "megbrain/serialization/helper.h" | #include "megbrain/serialization/helper.h" | ||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
@@ -225,6 +226,19 @@ void init_utils(py::module m) { | |||||
}, | }, | ||||
py::arg("path") = std::optional<std::string>()); | py::arg("path") = std::optional<std::string>()); | ||||
using mgb::imperative::TensorSanityCheck; | |||||
py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl") | |||||
.def(py::init<>()) | |||||
.def("enable", | |||||
[](TensorSanityCheck& checker) -> TensorSanityCheck& { | |||||
checker.enable(); | |||||
return checker; | |||||
}) | |||||
.def("disable", | |||||
[](TensorSanityCheck& checker) { | |||||
checker.disable(); | |||||
}); | |||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), | m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), | ||||
py::arg("port") = 0); | py::arg("port") = 0); | ||||
@@ -110,9 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
return out; | return out; | ||||
} | } | ||||
/*!
 * \brief Produce the output descriptors for this operator.
 *
 * Two descriptors are returned. Each carries a default-constructed
 * TensorLayout and the comp node of the first input.
 *
 * \param def operator definition (not inspected)
 * \param inputs input tensors; inputs[0] must exist
 */
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    constexpr size_t nr_outputs = 2;
    const auto& comp_node = inputs[0]->comp_node();
    SmallVector<LogicalTensorDesc> descs;
    for (size_t idx = 0; idx != nr_outputs; ++idx) {
        descs.push_back({TensorLayout(), comp_node});
    }
    return descs;
}
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | ||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.apply_on_physical_tensor(apply_on_physical_tensor) | .apply_on_physical_tensor(apply_on_physical_tensor) | ||||
.infer_output_attrs(infer_output_attrs) | |||||
.fallback(); | .fallback(); | ||||
} // namespace | } // namespace | ||||
@@ -0,0 +1,130 @@ | |||||
/** | |||||
* \file src/core/impl/imperative/tensor_sanity_check.cpp | |||||
* | |||||
* This file is part of MegBrain, a deep learning framework developed by Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. | |||||
* | |||||
*/ | |||||
#include "megbrain/imperative/tensor_sanity_check.h" | |||||
#include "./op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
/*!
 * \brief Compute a checksum over the bytes spanned by a tensor's layout.
 *
 * A tensor with zero elements yields a default-initialized result without
 * invoking the checksum operator. Otherwise the byte span of the layout is
 * viewed as a flat Byte tensor and handed to the megdnn Checksum operator of
 * the tensor's comp node, using a workspace cached per (thread, comp node).
 */
TensorChecksumCalc::ChecksumResult TensorChecksumCalc::calc(TensorPtr ptr) {
    auto&& dev_tensor = ptr->dev_tensor();
    if (dev_tensor.layout().total_nr_elems() == 0) {
        static ChecksumResult empty_checksum;
        return empty_checksum;
    }

    // Flat 1-d byte view covering [low_byte, high_byte) of the layout.
    auto byte_span = dev_tensor.layout().span();
    megdnn::TensorND flat;
    flat.raw_ptr = dev_tensor.raw_ptr() + byte_span.low_byte;
    flat.layout.init_contiguous_stride({byte_span.dist_byte()});
    flat.layout.dtype = dtype::Byte();

    auto cn = ptr->comp_node();

    // Only the cache lookup is done under the mutex; the storage found is
    // used after the lock is released.
    DeviceTensorStorage* cached_storage;
    {
        MGB_LOCK_GUARD(m_workspace_mtx);
        auto& per_thread = m_workspace[std::this_thread::get_id()];
        cached_storage = &per_thread.storage[cn];
    }

    cn.activate();
    auto checksum_opr =
            opr::intl::get_megdnn_global_opr<megdnn::Checksum>(cn);
    auto required_bytes = checksum_opr->get_workspace_in_bytes(flat.layout);
    cached_storage->comp_node(cn).ensure_size(required_bytes);

    megdnn::Workspace dnn_workspace;
    if (required_bytes != 0) {
        dnn_workspace = {cached_storage->ptr(), required_bytes};
    }
    return checksum_opr->exec(flat, dnn_workspace);
}
class TensorSanityCheckImpl { | |||||
public: | |||||
std::vector<std::tuple<OpTrait*, std::unique_ptr<ApplyOnPhysicalTensor>>> | |||||
hook_list; | |||||
std::unordered_map<TensorPtr, TensorChecksumCalc::ChecksumResult> | |||||
tensor2chksum; // TODO: may increase device memory overhead | |||||
TensorSanityCheckImpl() { | |||||
m_calc = std::make_unique<TensorChecksumCalc>(); | |||||
} | |||||
bool check(TensorPtr p); | |||||
private: | |||||
std::unique_ptr<TensorChecksumCalc> m_calc; | |||||
}; | |||||
/*!
 * \brief Compare a tensor's current checksum with the recorded one.
 *
 * A tensor that has no recorded checksum is registered with its current
 * checksum and reported as valid.
 *
 * \return true if the tensor is new or its checksum is unchanged
 */
bool TensorSanityCheckImpl::check(TensorPtr p) {
    auto current = m_calc->calc(p);
    auto known = tensor2chksum.find(p);
    if (known != tensor2chksum.end()) {
        return known->second == current;
    }
    tensor2chksum.emplace(p, current);
    return true;
}
void TensorSanityCheck::enable() { | |||||
CompNode::sync_all(); | |||||
OpTrait::for_each_trait([this](OpTrait& trait) { | |||||
auto backup = std::make_unique<ApplyOnPhysicalTensor>( | |||||
std::move(trait.apply_on_physical_tensor)); | |||||
trait.apply_on_physical_tensor = [this, backup = backup.get()] ( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
for (auto&& i: inputs) { | |||||
if (!m_checker->check(i)) { | |||||
mgb_throw(TensorChecksumCalc::Error, | |||||
"tensor modified before exec %s", print_op(def).c_str()); | |||||
} | |||||
} | |||||
auto output = (*backup)(def, inputs); | |||||
for (auto&& i: output) { | |||||
mgb_assert(m_checker->check(i)); | |||||
} | |||||
for (auto&& i: inputs) { | |||||
if (!m_checker->check(i)) { | |||||
mgb_throw(TensorChecksumCalc::Error, | |||||
"tensor modified after exec %s", print_op(def).c_str()); | |||||
} | |||||
} | |||||
return output; | |||||
}; | |||||
m_checker->hook_list.push_back({&trait, std::move(backup)}); | |||||
}); | |||||
} | |||||
/*!
 * \brief Put the saved handlers back into their op traits and drop all
 *      recorded checksums.
 *
 * Handlers are restored in reverse installation order. If the same trait
 * occurs more than once in the hook list, its oldest saved handler is then
 * the one written last, so no wrapper referring to a released backup is left
 * installed. Calling disable() with no hooks installed only clears the
 * (already empty) state.
 */
void TensorSanityCheck::disable() {
    auto& hooks = m_checker->hook_list;
    for (auto it = hooks.rbegin(); it != hooks.rend(); ++it) {
        std::get<0>(*it)->apply_on_physical_tensor =
                std::move(*std::get<1>(*it));
    }
    m_checker->tensor2chksum.clear();
    hooks.clear();
}
TensorSanityCheck::TensorSanityCheck() { | |||||
m_checker = std::make_unique<TensorSanityCheckImpl>(); | |||||
} | |||||
/*!
 * \brief Remove any hooks that are still installed before the object dies.
 *
 * The wrappers installed by enable() capture `this`; leaving them in the op
 * traits after destruction would make every later operator call dereference
 * a dangling pointer. disable() does nothing harmful when no hooks are
 * installed.
 */
TensorSanityCheck::~TensorSanityCheck() {
    disable();
}
std::string TensorSanityCheck::print_op(const OpDef& def){ | |||||
auto* opr_attr = def.try_cast_final<const OprAttr>(); | |||||
if(opr_attr){ | |||||
return std::string("OprAttr:") + opr_attr->type; | |||||
} | |||||
return def.dyn_typeinfo()->name; | |||||
} | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -0,0 +1,50 @@ | |||||
/** | |||||
* \file src/core/include/megbrain/tensor_sanity_check.h | |||||
* | |||||
* This file is part of MegBrain, a deep learning framework developed by Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. | |||||
* | |||||
*/ | |||||
#pragma once

#include <memory>
#include <mutex>
#include <string>
#include <thread>

#include "megbrain/comp_node_env.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/plugin/var_sanity_check.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs/general.h"
namespace mgb { | |||||
namespace imperative { | |||||
class TensorChecksumCalc { | |||||
public: | |||||
using ChecksumResult = megdnn::opr_result::Checksum; | |||||
using Error = VarSanityCheckError; | |||||
struct WorkspaceCache { | |||||
//! var comp node to workspace | |||||
CompNode::UnorderedMap<DeviceTensorStorage> storage; | |||||
}; | |||||
ThinHashMap<std::thread::id, WorkspaceCache> m_workspace; | |||||
std::mutex m_workspace_mtx; | |||||
ChecksumResult calc(TensorPtr ptr); | |||||
TensorChecksumCalc() {} | |||||
}; | |||||
class TensorSanityCheckImpl; | |||||
class TensorSanityCheck { | |||||
public: | |||||
TensorSanityCheck(); | |||||
~TensorSanityCheck(); | |||||
void enable(); | |||||
void disable(); | |||||
std::string print_op(const OpDef& def); | |||||
private: | |||||
std::unique_ptr<TensorSanityCheckImpl> m_checker; | |||||
}; | |||||
} // namespace imperative | |||||
} // namespace mgb |