Browse Source

refactor(mgb): add TensorND::proxy_to_default_cpu

GitOrigin-RevId: 3ab8525f1c
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
5e7d2a91c2
2 changed files with 42 additions and 7 deletions
  1. +33
    -7
      src/core/include/megbrain/tensor.h
  2. +9
    -0
      src/core/test/tensor.cpp

+ 33
- 7
src/core/include/megbrain/tensor.h View File

@@ -101,6 +101,14 @@ class Slice {
SubTensorSpec apply(TensorLayout layout, int axis) const;
};

template <class Trait> class TensorStorage;

class DeviceTensorStorageTrait;
class HostTensorStorageTrait;

using HostTensorStorage = TensorStorage<HostTensorStorageTrait>;
using DeviceTensorStorage = TensorStorage<DeviceTensorStorageTrait>;

/*!
* \brief manager for raw tensor memory
*
@@ -230,6 +238,18 @@ class TensorStorage {
std::enable_if<!std::is_same<Trait, RTrait>::value>::type>
static TensorStorage make_proxy(const TensorStorage<RTrait> &src);

/*!
 * \brief create a DeviceTensorStorage on CompNode::default_cpu() that
 *      aliases (shares) the memory held by this storage
 *
 * Only available when this is a HostTensorStorage (enforced via SFINAE).
 * Alignment of the underlying buffer is not checked.
 */
template <bool enable = true,
          typename = std::enable_if_t<
                  enable && std::is_same<Trait, HostTensorStorageTrait>::value>>
DeviceTensorStorage proxy_to_default_cpu() const {
    // force lazy allocation so m_data is valid before sharing it
    ptr();
    DeviceTensorStorage proxy{true, CompNode::default_cpu(), m_size,
                              m_capacity, m_offset, m_data};
    return proxy;
}

//! shortcut for raw_storage().use_count(), but won't trigger lazy alloc
size_t use_count() const {
if (m_size > m_capacity) {
@@ -284,11 +304,12 @@ class TensorStorage {

[[noreturn]] static void on_invalid_comp_node();
};
class DeviceTensorStorageTrait;
class HostTensorStorageTrait;

using HostTensorStorage = TensorStorage<HostTensorStorageTrait>;
using DeviceTensorStorage = TensorStorage<DeviceTensorStorageTrait>;

template<class TensorStorage> class TensorND;

using HostTensorND = TensorND<HostTensorStorage>;
using DeviceTensorND = TensorND<DeviceTensorStorage>;

/*!
* \brief n-dimensional tensor
@@ -519,10 +540,15 @@ class TensorND {
ret.reset(TensorStorage::make_proxy(src.storage()), src.layout());
return ret;
}
};

using HostTensorND = TensorND<HostTensorStorage>;
using DeviceTensorND = TensorND<DeviceTensorStorage>;
/*!
 * \brief make a DeviceTensorND on CompNode::default_cpu() sharing memory
 *      with this tensor; see HostTensorStorage::proxy_to_default_cpu
 *
 * Only available when this is a HostTensorND (enforced via SFINAE).
 */
template <bool enable = true,
          typename = std::enable_if_t<
                  enable && std::is_same<TensorStorage, HostTensorStorage>::value>>
DeviceTensorND proxy_to_default_cpu() const {
    DeviceTensorND result;
    // share storage (no copy) and keep the exact same layout
    result.reset(storage().proxy_to_default_cpu(), layout());
    return result;
}
};

/*!
* \brief call memset in the data of a device tensor


+ 9
- 0
src/core/test/tensor.cpp View File

@@ -418,4 +418,13 @@ TEST(TestTensor, CpuCudaD2DCopy) {
}
}

TEST(TestTensor, ProxyToDefaultCPU) {
    // build a host tensor on any available comp node, then proxy it to the
    // default CPU node; the proxy must alias the same memory and layout
    auto comp_node = CompNode::load("xpux");
    auto host = HostTensorND(comp_node,
                             TensorLayout({1, 2, 3}, dtype::Float32{}));
    auto proxied = host.proxy_to_default_cpu();
    ASSERT_EQ(proxied.comp_node(), CompNode::default_cpu());
    ASSERT_EQ(host.layout(), proxied.layout());
    ASSERT_EQ(host.raw_ptr(), proxied.raw_ptr());
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save