
feat(opr/nvof): add nvof operator

This reverts commit 18b84072ac.

GitOrigin-RevId: 3b7622784d
release-1.1
Megvii Engine Team · 4 years ago · commit 95f6b53183
17 changed files with 2276 additions and 3 deletions
  1. imperative/python/megengine/functional/nn.py (+29, -0)
  2. src/CMakeLists.txt (+1, -1)
  3. src/opr/impl/misc.cpp (+85, -0)
  4. src/opr/impl/misc.oprdecl (+3, -0)
  5. src/opr/impl/misc.sereg.h (+3, -0)
  6. src/opr/impl/nvof/NvOF.cpp (+248, -0)
  7. src/opr/impl/nvof/NvOF.h (+241, -0)
  8. src/opr/impl/nvof/NvOFCuda.cpp (+343, -0)
  9. src/opr/impl/nvof/NvOFCuda.h (+178, -0)
  10. src/opr/impl/nvof/NvOFDefines.h (+55, -0)
  11. src/opr/impl/nvof/denseflownvidia.cpp (+202, -0)
  12. src/opr/impl/nvof/denseflownvidia.h (+80, -0)
  13. src/opr/impl/nvof/nvOpticalFlowCommon.h (+510, -0)
  14. src/opr/impl/nvof/nvOpticalFlowCuda.h (+258, -0)
  15. src/opr/include/megbrain/opr/misc.h (+38, -0)
  16. src/serialization/impl/schema.fbs (+1, -2)
  17. tools/param_defs/mgb_opr_param_defs.py (+1, -0)

imperative/python/megengine/functional/nn.py (+29, -0)

@@ -1536,6 +1536,35 @@ def nms(
return keep_inds


def nvof(src: Tensor, precision: int = 1) -> Tensor:
r"""
Computes dense optical flow using the NVIDIA Optical Flow SDK.

:param src: input tensor with shape (n, t, h, w, c=4) and dtype uint8.
:param precision: performance preset; 0: NV_OF_PERF_LEVEL_SLOW, 1: NV_OF_PERF_LEVEL_MEDIUM, 2: NV_OF_PERF_LEVEL_FAST.
:return: output tensor with shape (n, t-1, h//4, w//4, c=2) and dtype int16.

.. code-block:: python

import numpy as np
from megengine import tensor
import megengine.functional as F

x = np.random.randint(0, 256, (1, 2, 224, 244, 4)).astype("uint8")
src = tensor(x)
result = F.nn.nvof(src, precision=1)
print(result.numpy())

"""
assert isinstance(src, (Tensor, megbrain_graph.VarNode)), "src must be Tensor type"
assert src.ndim == 5 and src.shape[4] == 4

src = src.detach()

op = builtin.NvOf(precision=precision)
return apply(op, src)[0]
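The int16 output holds raw NVOF flow vectors in S10.5 fixed point (5 fractional bits; see NV_OF_FLOW_VECTOR in nvOpticalFlowCommon.h later in this diff), so real-valued pixel displacements are recovered by dividing by 32. A minimal sketch, assuming the operator returns the SDK output unscaled:

import numpy as np

def decode_flow_s10_5(flow: np.ndarray) -> np.ndarray:
    # S10.5: sign bit, 10 integer bits, 5 fractional bits, so the raw
    # int16 value divided by 2**5 gives the flow in pixels.
    assert flow.dtype == np.int16
    return flow.astype(np.float32) / 32.0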


from .loss import * # isort:skip


src/CMakeLists.txt (+1, -1)

@@ -2,7 +2,7 @@ if(MGE_WITH_JIT_MLIR)
add_subdirectory(jit/impl/mlir/ir)
endif()

file(GLOB_RECURSE SOURCES core/impl/*.cpp gopt/impl/*.cpp opr/impl/*.cpp plugin/impl/*.cpp serialization/impl/*.cpp core/impl/*.inl gopt/impl/*.inl opr/impl/*.inl plugin/impl/*.inl serialization/impl/*.inl)
file(GLOB_RECURSE SOURCES core/impl/*.cpp gopt/impl/*.cpp opr/impl/*.cpp opr/impl/nvof/*.cpp plugin/impl/*.cpp serialization/impl/*.cpp core/impl/*.inl gopt/impl/*.inl opr/impl/*.inl plugin/impl/*.inl serialization/impl/*.inl)

if(MGE_WITH_JIT)
file(GLOB_RECURSE SOURCES_ jit/impl/*.cpp jit/impl/*.inl)


src/opr/impl/misc.cpp (+85, -0)

@@ -159,6 +159,91 @@ void Cumsum::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
}

/* ================= NvOf ================= */

#if MGB_CUDA
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf);

NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config)
: Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} {
constexpr size_t NDIM = 5;
mgb_assert(opr->dtype() == dtype::Uint8());
add_input({opr});
//! NvOf has only one output
add_output(None);

mgb_log_debug("init nvof engine with precision: %u", m_param.precision);
auto input_shape = this->input()[0]->shape();

//! nvof input format: nthwc4
mgb_assert(input_shape.ndim == NDIM);
//! currently only 4-channel RGBA input data is supported
mgb_assert(input_shape[4] == 4);

for (size_t i = 0; i < NDIM; i++) {
vshape.push_back(input_shape[i]);
}
}

void NvOf::init_output_dtype() {
output(0)->dtype(dtype::Int16());
}

SymbolVar NvOf::make(SymbolVar opr, const Param& param,
const OperatorNodeConfig& config) {
return opr.insert_single_output_opr<NvOf>(opr.node(), param, config);
}

void NvOf::scn_do_execute() {
auto c = this->comp_node();
//! the comp_node may be on CUDA or CPU (e.g. lar with --cpu);
//! on CUDA we need to sync because NvOF runs on a different stream
if (CompNode::DeviceType::CUDA == c.device_type()) {
c.sync();
} else {
mgb_log_warn(
"NvOf opr on non CUDA comp_node, which will triger H2D and "
"D2H!!");
}

//! create the NvOF engine on the same device as the comp_node; the
//! device id is unknown in NvOf::NvOf, so init here in scn_do_execute
std::lock_guard<std::mutex> lock(m_lock);
if (init_flag == false) {
//! the NVOF SDK does not implement P2P copy, so init the nvof
//! engine on the same device as the mgb comp_node
nv_flow_extractor = std::make_shared<NVFlowExtractor>(
c.locator().device, vshape, m_param.precision, true, true);
init_flag = true;
}

nv_flow_extractor->extract_flow(
static_cast<unsigned char*>(
input(0)->dev_tensor().as_megdnn().raw_ptr),
vshape,
reinterpret_cast<int16_t*>(
output(0)->dev_tensor().as_megdnn().raw_ptr));
}

void NvOf::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
auto ishp = iv.val.at(0).shape();
SmallVector<size_t> tv;
tv.push_back(ishp[0]);
tv.push_back(ishp[1] - 1);
tv.push_back(ishp[2] / 4);
tv.push_back(ishp[3] / 4);
tv.push_back(ishp[4] / 2);
dest = tv;

return true;
};
owner_graph()->static_infer_manager().register_shape_infer(
output(0),
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
}
#endif
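The shape rule registered above maps (n, t, h, w, 4) to (n, t-1, h//4, w//4, 2). Note it uses floor division, while NvOF::Init below rounds the output grid dimensions up; the two agree when h and w are multiples of the 4x4 output grid. A small sketch of the rule:

# Sketch of the static shape inference above, assuming the default
# 4x4 output grid (NV_OF_OUTPUT_VECTOR_GRID_SIZE_4).
def nvof_out_shape(n, t, h, w):
    return (n, t - 1, h // 4, w // 4, 2)

assert nvof_out_shape(1, 2, 224, 224) == (1, 1, 56, 56, 2)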

/* ================= CondTake ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);


src/opr/impl/misc.oprdecl (+3, -0)

@@ -63,5 +63,8 @@ decl_opr('TopK',
inputs=['data', 'k'], params='TopK',
desc='Select the top k values from sorted result.')

decl_opr('NvOf',
inputs=['src'], params='NvOf',
desc='operator implementing the NVIDIA Optical Flow SDK.')

# vim: ft=python

src/opr/impl/misc.sereg.h (+3, -0)

@@ -70,6 +70,9 @@ namespace opr {
using CumsumV1 = opr::Cumsum;
MGB_SEREG_OPR(CumsumV1, 1);

#if MGB_CUDA
MGB_SEREG_OPR(NvOf, 1);
#endif

} // namespace opr
} // namespace mgb


src/opr/impl/nvof/NvOF.cpp (+248, -0)

@@ -0,0 +1,248 @@
/*
* Copyright 2018-2019 NVIDIA Corporation. All rights reserved.
*
* Please refer to the NVIDIA end user license agreement (EULA) associated
* with this source code for terms and conditions that govern your use of
* this software. Any use, reproduction, disclosure, or distribution of
* this software and related documentation outside the terms of the EULA
* is strictly prohibited.
*
*/

/**
* \file src/opr/impl/nvof/NvOF.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#ifdef _WIN32
#include <Windows.h>
#else
#include <dlfcn.h>
#endif

#include "NvOF.h"

NvOF::NvOF(uint32_t nWidth, uint32_t nHeight, NV_OF_BUFFER_FORMAT eInBufFmt, NV_OF_MODE eMode,
NV_OF_PERF_LEVEL preset) :
m_nOutGridSize(NV_OF_OUTPUT_VECTOR_GRID_SIZE_MAX),
m_ePreset(preset),
m_ofMode(eMode)
{
m_inputElementSize = 1;
if (eInBufFmt == NV_OF_BUFFER_FORMAT_ABGR8)
m_inputElementSize = 4;


memset(&m_inputBufferDesc, 0, sizeof(m_inputBufferDesc));
m_inputBufferDesc.width = nWidth;
m_inputBufferDesc.height = nHeight;
m_inputBufferDesc.bufferFormat = eInBufFmt;
m_inputBufferDesc.bufferUsage = NV_OF_BUFFER_USAGE_INPUT;

}

bool NvOF::CheckGridSize(uint32_t nOutGridSize)
{
uint32_t size;
DoGetOutputGridSizes(nullptr, &size);

std::unique_ptr<uint32_t[]> val(new uint32_t[size]);
DoGetOutputGridSizes(val.get(), &size);

for (uint32_t i = 0; i < size; i++)
{
if (nOutGridSize == val[i])
{
return true;
}
}
return false;
}

bool NvOF::GetNextMinGridSize(uint32_t nOutGridSize, uint32_t& nextMinOutGridSize)
{
uint32_t size;
DoGetOutputGridSizes(nullptr, &size);

std::unique_ptr<uint32_t[]> val(new uint32_t[size]);
DoGetOutputGridSizes(val.get(), &size);

nextMinOutGridSize = NV_OF_OUTPUT_VECTOR_GRID_SIZE_MAX;
for (uint32_t i = 0; i < size; i++)
{
if (nOutGridSize == val[i])
{
nextMinOutGridSize = nOutGridSize;
return true;
}
if (nOutGridSize < val[i] && val[i] < nextMinOutGridSize)
{
nextMinOutGridSize = val[i];
}
}
return (nextMinOutGridSize >= NV_OF_OUTPUT_VECTOR_GRID_SIZE_MAX) ? false : true;
}

void NvOF::Init(uint32_t nOutGridSize)
{
m_nOutGridSize = nOutGridSize;

auto nOutWidth = (m_inputBufferDesc.width + m_nOutGridSize - 1) / m_nOutGridSize;
auto nOutHeight = (m_inputBufferDesc.height + m_nOutGridSize - 1) / m_nOutGridSize;

auto outBufFmt = NV_OF_BUFFER_FORMAT_SHORT2;
if (m_ofMode == NV_OF_MODE_OPTICALFLOW)
{
outBufFmt = NV_OF_BUFFER_FORMAT_SHORT2;
m_outputElementSize = sizeof(NV_OF_FLOW_VECTOR);
}
else if (m_ofMode == NV_OF_MODE_STEREODISPARITY)
{
outBufFmt = NV_OF_BUFFER_FORMAT_SHORT;
m_outputElementSize = sizeof(NV_OF_STEREO_DISPARITY);
}
else
{
mgb_throw(MegBrainError, "NVOF: Unsupported OF mode err type: NV_OF_ERR_INVALID_PARAM");
}

memset(&m_outputBufferDesc, 0, sizeof(m_outputBufferDesc));
m_outputBufferDesc.width = nOutWidth;
m_outputBufferDesc.height = nOutHeight;
m_outputBufferDesc.bufferFormat = outBufFmt;
m_outputBufferDesc.bufferUsage = NV_OF_BUFFER_USAGE_OUTPUT;

memset(&m_costBufferDesc, 0, sizeof(m_costBufferDesc));
m_costBufferDesc.width = nOutWidth;
m_costBufferDesc.height = nOutHeight;
m_costBufferDesc.bufferFormat = NV_OF_BUFFER_FORMAT_UINT;
m_costBufferDesc.bufferUsage = NV_OF_BUFFER_USAGE_COST;
m_costBufElementSize = sizeof(uint32_t);

memset(&m_hintBufferDesc, 0, sizeof(m_hintBufferDesc));
m_hintBufferDesc.width = nOutWidth;
m_hintBufferDesc.height = nOutHeight;
m_hintBufferDesc.bufferFormat = outBufFmt;
m_hintBufferDesc.bufferUsage = NV_OF_BUFFER_USAGE_HINT;
m_hintBufElementSize = m_outputElementSize;

memset(&m_initParams, 0, sizeof(m_initParams));
m_initParams.width = m_inputBufferDesc.width;
m_initParams.height = m_inputBufferDesc.height;
m_initParams.enableExternalHints = NV_OF_FALSE;
m_initParams.enableOutputCost = NV_OF_FALSE;
m_initParams.hintGridSize = NV_OF_HINT_VECTOR_GRID_SIZE_UNDEFINED;
m_initParams.outGridSize = (NV_OF_OUTPUT_VECTOR_GRID_SIZE)m_nOutGridSize;
m_initParams.mode = m_ofMode;
m_initParams.perfLevel = m_ePreset;
DoInit(m_initParams);
}

void NvOF::Execute(NvOFBuffer* image1,
NvOFBuffer* image2,
NvOFBuffer* outputBuffer,
NvOFBuffer* hintBuffer,
NvOFBuffer* costBuffer)
{
NV_OF_EXECUTE_INPUT_PARAMS exeInParams;
NV_OF_EXECUTE_OUTPUT_PARAMS exeOutParams;

memset(&exeInParams, 0, sizeof(exeInParams));
exeInParams.inputFrame = image1->getOFBufferHandle();
exeInParams.referenceFrame = image2->getOFBufferHandle();
exeInParams.disableTemporalHints = NV_OF_FALSE;
exeInParams.externalHints = m_initParams.enableExternalHints == NV_OF_TRUE ? hintBuffer->getOFBufferHandle() : nullptr;

memset(&exeOutParams, 0, sizeof(exeOutParams));
exeOutParams.outputBuffer = outputBuffer->getOFBufferHandle();
exeOutParams.outputCostBuffer = m_initParams.enableOutputCost == NV_OF_TRUE ? costBuffer->getOFBufferHandle() : nullptr;
DoExecute(exeInParams, exeOutParams);
}


std::vector<std::unique_ptr<NvOFBuffer>>
NvOF::CreateBuffers(NV_OF_BUFFER_USAGE usage, uint32_t numBuffers)
{
std::vector<std::unique_ptr<NvOFBuffer>> ofBuffers;

if (usage == NV_OF_BUFFER_USAGE_INPUT)
{
ofBuffers = DoAllocBuffers(m_inputBufferDesc, m_inputElementSize, numBuffers);
}
else if (usage == NV_OF_BUFFER_USAGE_OUTPUT)
{
ofBuffers = DoAllocBuffers(m_outputBufferDesc, m_outputElementSize, numBuffers);
}
else if (usage == NV_OF_BUFFER_USAGE_COST)
{
ofBuffers = DoAllocBuffers(m_costBufferDesc, m_costBufElementSize, numBuffers);
}
else if (usage == NV_OF_BUFFER_USAGE_HINT)
{
ofBuffers = DoAllocBuffers(m_hintBufferDesc, m_hintBufElementSize, numBuffers);
}
else
{
mgb_throw(MegBrainError, "NVOF: Invalid parameter err type: NV_OF_ERR_GENERIC");
}

return ofBuffers;
}

std::vector<std::unique_ptr<NvOFBuffer>>
NvOF::CreateBuffers(uint32_t nWidth, uint32_t nHeight, NV_OF_BUFFER_USAGE usage, uint32_t numBuffers)
{
std::vector<std::unique_ptr<NvOFBuffer>> ofBuffers;

NV_OF_BUFFER_DESCRIPTOR bufferDesc;

if (usage == NV_OF_BUFFER_USAGE_OUTPUT)
{
bufferDesc.width = nWidth;
bufferDesc.height = nHeight;
bufferDesc.bufferFormat = m_outputBufferDesc.bufferFormat;
bufferDesc.bufferUsage = NV_OF_BUFFER_USAGE_OUTPUT;

ofBuffers = DoAllocBuffers(bufferDesc, m_outputElementSize, numBuffers);
}
else
{
mgb_throw(MegBrainError, "NVOF: Invalid parameter err type: NV_OF_ERR_GENERIC");
}

return ofBuffers;
}

void NvOFAPI::LoadNvOFAPI()
{
#if defined(_WIN32)
#if defined(_WIN64)
HMODULE hModule = LoadLibrary(TEXT("nvofapi64.dll"));
#else
HMODULE hModule = LoadLibrary(TEXT("nvofapi.dll"));
#endif
#else
void *hModule = dlopen("libnvidia-opticalflow.so.1", RTLD_LAZY);
#endif
if (hModule == NULL)
{
mgb_throw(
MegBrainError,
"NVOF: NVOF library file not found. Please ensure that the "
"NVIDIA driver is installed type: NV_OF_ERR_OF_NOT_AVAILABLE");
}

m_hModule = hModule;
}

#endif

src/opr/impl/nvof/NvOF.h (+241, -0)

@@ -0,0 +1,241 @@
/*
* Copyright 2018-2019 NVIDIA Corporation. All rights reserved.
*
* Please refer to the NVIDIA end user license agreement (EULA) associated
* with this source code for terms and conditions that govern your use of
* this software. Any use, reproduction, disclosure, or distribution of
* this software and related documentation outside the terms of the EULA
* is strictly prohibited.
*
*/

/**
* \file src/opr/impl/nvof/NvOF.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include <cuda.h>
#include "megbrain_build_config.h"

#if MGB_CUDA
#pragma once
#include <stdint.h>
#include <string.h>
#include <iostream>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>
#include "NvOFDefines.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "nvOpticalFlowCommon.h"

using namespace mgb;
/**
* @brief Exception class for error reporting from NvOFAPI calls
*/
class NvOFException : public std::exception {
public:
NvOFException(const std::string& errorStr, const NV_OF_STATUS errorCode)
: m_errorString(errorStr), m_errorCode(errorCode) {}
virtual ~NvOFException() throw() {}
virtual const char* what() const throw() { return m_errorString.c_str(); }
NV_OF_STATUS getErrorCode() const { return m_errorCode; }
const std::string& getErrorString() const { return m_errorString; }

private:
std::string m_errorString;
NV_OF_STATUS m_errorCode;
};

#define NVOF_API_CALL(nvOFAPI) \
do { \
NV_OF_STATUS errorCode = nvOFAPI; \
if (errorCode != NV_OF_SUCCESS) { \
std::ostringstream errorLog; \
errorLog << #nvOFAPI << " returned error " << errorCode; \
std::cout << "Exception: " << __FILE__ << ":" << __LINE__ << ":" \
<< errorLog.str() << std::endl; \
mgb_throw(MegBrainError, "NVOF_API_CALL ERROR"); \
} \
} while (0)

/*
* NvOFBuffer is a wrapper over the NvOFGPUBufferHandle object defined in
* NVOF API and provides methods for various operations associated with the
* GPU buffer.
*/
class NvOFBuffer {
public:
virtual ~NvOFBuffer() {}
uint32_t getWidth() { return m_width; }
uint32_t getHeight() { return m_height; }
uint32_t getElementSize() { return m_elementSize; }
NV_OF_BUFFER_FORMAT getBufferFormat() { return m_eBufFmt; }
NV_OF_BUFFER_USAGE getBufferUsage() { return m_eBufUsage; }

/*
* Uploads data from the host buffer specified in 'pData' to the GPU buffer.
*/
virtual void UploadData(const void* pData, CUmemorytype mem_type) = 0;

/*
* Download data to the host buffer specified in 'pData' from the GPU
* buffer.
*/
virtual void DownloadData(void* pData, CUmemorytype mem_type) = 0;

/*
* SyncBuffer method makes sure that data upload is complete on input/hint
* GPU buffer. It also makes sure that data is ready for download from
* output/cost GPU buffer.
*/
virtual void SyncBuffer() {}

protected:
NvOFBuffer(const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize)
: m_hGPUBuffer(nullptr),
m_width(desc.width),
m_elementSize(elementSize),
m_height(desc.height),
m_eBufUsage(desc.bufferUsage),
m_eBufFmt(desc.bufferFormat) {}
NvOFGPUBufferHandle getOFBufferHandle() { return m_hGPUBuffer; }

NvOFGPUBufferHandle m_hGPUBuffer;

private:
uint32_t m_width;
uint32_t m_elementSize;
uint32_t m_height;
NV_OF_BUFFER_USAGE m_eBufUsage;
NV_OF_BUFFER_FORMAT m_eBufFmt;
friend class NvOF;
};

/*
* NvOFAPI is a helper class for loading the library which implements the
* NVOF API. Classes derived from this provide access to the common and
* interface-specific API calls from NVOF API.
*/
class NvOFAPI {
public:
NvOFAPI() { LoadNvOFAPI(); }
virtual ~NvOFAPI() {}

protected:
HMODULE m_hModule;
std::mutex m_lock;

private:
void LoadNvOFAPI();
};

/**
* @brief Base class for different optical flow interfaces
*/
class NvOF {
public:
/**
* @brief NvOF class virtual destructor
*/
virtual ~NvOF(){};

/**
* @brief Create one or more GPU buffers for the specified usage mode
*/
std::vector<NvOFBufferObj> CreateBuffers(NV_OF_BUFFER_USAGE usage,
uint32_t numBuffers);

/**
* @brief Create one or more GPU buffers for the specified width, height and
* usage mode.
*/
std::vector<NvOFBufferObj> CreateBuffers(uint32_t nWidth, uint32_t nHeight,
NV_OF_BUFFER_USAGE usage,
uint32_t numBuffers);

/**
* @brief This function is used to estimate the optical flow from image1 to
* image2.
*/
void Execute(NvOFBuffer* image1, NvOFBuffer* image2,
NvOFBuffer* outputBuffer, NvOFBuffer* hintBuffer = nullptr,
NvOFBuffer* costBuffer = nullptr);

protected:
/**
* @brief NvOF class constructor.
* NvOF class constructor cannot be called directly by the application.
*/
NvOF(uint32_t nWidth, uint32_t nHeight, NV_OF_BUFFER_FORMAT eInBufFmt,
NV_OF_MODE eMode = NV_OF_MODE_OPTICALFLOW,
NV_OF_PERF_LEVEL preset = NV_OF_PERF_LEVEL_SLOW);

public:
void Init(uint32_t nOutGridSize);

/*
* Check whether the grid size is supported by the hardware
*/
bool CheckGridSize(uint32_t nOutGridSize);

/*
* Retrieves the next minimum grid size supported for the specified grid
* size
*/
bool GetNextMinGridSize(uint32_t nOutGridSize,
uint32_t& nextMinOutGridSize);

private:
/*
* Retrieves the output grid sizes supported
*/
virtual void DoGetOutputGridSizes(uint32_t* vals, uint32_t* size) = 0;

/*
* Initializes the NVOF API.
*/
virtual void DoInit(const NV_OF_INIT_PARAMS& initParams) = 0;

/*
* Executes the estimation of optical flow/stereo disparity between 2
* images.
*/
virtual void DoExecute(const NV_OF_EXECUTE_INPUT_PARAMS& executeInParams,
NV_OF_EXECUTE_OUTPUT_PARAMS& executeOutParams) = 0;

/*
* Allocates one or more GPU buffers.
*/
virtual std::vector<NvOFBufferObj> DoAllocBuffers(
NV_OF_BUFFER_DESCRIPTOR ofBufferDesc, uint32_t elementSize,
uint32_t numBuffers) = 0;

protected:
uint32_t m_nOutGridSize;
NV_OF_PERF_LEVEL m_ePreset;
NV_OF_MODE m_ofMode;
NV_OF_BUFFER_DESCRIPTOR m_inputBufferDesc;
NV_OF_BUFFER_DESCRIPTOR m_outputBufferDesc;
NV_OF_BUFFER_DESCRIPTOR m_costBufferDesc;
NV_OF_BUFFER_DESCRIPTOR m_hintBufferDesc;

uint32_t m_outputElementSize;
uint32_t m_inputElementSize;
uint32_t m_costBufElementSize;
uint32_t m_hintBufElementSize;

NV_OF_INIT_PARAMS m_initParams;
};

#endif

src/opr/impl/nvof/NvOFCuda.cpp (+343, -0)

@@ -0,0 +1,343 @@
/*
* Copyright 2018-2019 NVIDIA Corporation. All rights reserved.
*
* Please refer to the NVIDIA end user license agreement (EULA) associated
* with this source code for terms and conditions that govern your use of
* this software. Any use, reproduction, disclosure, or distribution of
* this software and related documentation outside the terms of the EULA
* is strictly prohibited.
*
*/

/**
* \file src/opr/impl/nvof/NvOFCuda.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#ifndef _WIN32
#include <dlfcn.h>
#endif
#include "megbrain/common.h"
#include "NvOFCuda.h"

NvOFCudaAPI::NvOFCudaAPI(CUcontext cuContext, CUstream inputStream, CUstream outputStream)
: m_inputStream(inputStream), m_outputStream(outputStream), m_cuContext(cuContext)
{
typedef NV_OF_STATUS(NVOFAPI *PFNNvOFAPICreateInstanceCuda)(uint32_t apiVer, NV_OF_CUDA_API_FUNCTION_LIST* cudaOf);
#if defined(_WIN32)
PFNNvOFAPICreateInstanceCuda NvOFAPICreateInstanceCuda = (PFNNvOFAPICreateInstanceCuda)GetProcAddress(m_hModule, "NvOFAPICreateInstanceCuda");
#else
PFNNvOFAPICreateInstanceCuda NvOFAPICreateInstanceCuda = (PFNNvOFAPICreateInstanceCuda)dlsym(m_hModule, "NvOFAPICreateInstanceCuda");
#endif
if (!NvOFAPICreateInstanceCuda)
{
mgb_throw(MegBrainError,
"NVOF: Cannot find NvOFAPICreateInstanceCuda() entry in NVOF "
"library err type: NV_OF_ERR_OF_NOT_AVAILABLE");
}

m_ofAPI.reset(new NV_OF_CUDA_API_FUNCTION_LIST());

NVOF_API_CALL(NvOFAPICreateInstanceCuda(NV_OF_API_VERSION, m_ofAPI.get()));
NVOF_API_CALL(m_ofAPI->nvCreateOpticalFlowCuda(m_cuContext, &m_hOF));
NVOF_API_CALL(m_ofAPI->nvOFSetIOCudaStreams(m_hOF, m_inputStream, m_outputStream));
}

NvOFCudaAPI::~NvOFCudaAPI()
{
if (m_ofAPI)
{
m_ofAPI->nvOFDestroy(m_hOF);
}
}

CUstream NvOFCudaAPI::GetCudaStream(NV_OF_BUFFER_USAGE usage)
{
CUstream stream = 0;
if (usage == NV_OF_BUFFER_USAGE_INPUT)
{
stream = m_inputStream;
}
else if ((usage == NV_OF_BUFFER_USAGE_OUTPUT) ||
(usage == NV_OF_BUFFER_USAGE_COST) ||
(usage == NV_OF_BUFFER_USAGE_HINT))
{
stream = m_outputStream;
}
return stream;
}

NvOFObj NvOFCuda::Create(CUcontext cuContext, uint32_t nWidth, uint32_t nHeight,
NV_OF_BUFFER_FORMAT eInBufFmt,
NV_OF_CUDA_BUFFER_TYPE eInBufType,
NV_OF_CUDA_BUFFER_TYPE eOutBufType,
NV_OF_MODE eMode,
NV_OF_PERF_LEVEL preset,
CUstream inputStream,
CUstream outputStream)
{
std::unique_ptr<NvOF> ofObj(new NvOFCuda(cuContext,
nWidth,
nHeight,
eInBufFmt,
eInBufType,
eOutBufType,
eMode,
preset,
inputStream,
outputStream));
return ofObj;
}

NvOFCuda::NvOFCuda(CUcontext cuContext,
uint32_t nWidth,
uint32_t nHeight,
NV_OF_BUFFER_FORMAT eInBufFmt,
NV_OF_CUDA_BUFFER_TYPE eInBufType,
NV_OF_CUDA_BUFFER_TYPE eOutBufType,
NV_OF_MODE eMode,
NV_OF_PERF_LEVEL preset,
CUstream inputStream,
CUstream outputStream)
: NvOF(nWidth, nHeight, eInBufFmt, eMode, preset),
m_cuContext(cuContext),
m_eInBufType(eInBufType),
m_eOutBufType(eOutBufType)
{
m_NvOFAPI = std::make_shared<NvOFCudaAPI>(m_cuContext, inputStream, outputStream);
}

void NvOFCuda::DoGetOutputGridSizes(uint32_t* vals, uint32_t* size)
{
NVOF_API_CALL(m_NvOFAPI->GetAPI()->nvOFGetCaps(m_NvOFAPI->GetHandle(), NV_OF_CAPS_SUPPORTED_OUTPUT_GRID_SIZES, vals, size));
}

void NvOFCuda::DoExecute(const NV_OF_EXECUTE_INPUT_PARAMS& executeInParams,
NV_OF_EXECUTE_OUTPUT_PARAMS& executeOutParams)
{
NVOF_API_CALL(m_NvOFAPI->GetAPI()->nvOFExecute(m_NvOFAPI->GetHandle(), &executeInParams, &executeOutParams));
}

void NvOFCuda::DoInit(const NV_OF_INIT_PARAMS& initParams)
{
NVOF_API_CALL(m_NvOFAPI->GetAPI()->nvOFInit(m_NvOFAPI->GetHandle(), &initParams));
}

NV_OF_CUDA_BUFFER_TYPE NvOFCuda::GetBufferType(NV_OF_BUFFER_USAGE usage)
{
NV_OF_CUDA_BUFFER_TYPE bufferType = NV_OF_CUDA_BUFFER_TYPE_UNDEFINED;
if (usage == NV_OF_BUFFER_USAGE_INPUT)
{
bufferType = m_eInBufType;
}
else if ((usage == NV_OF_BUFFER_USAGE_OUTPUT) ||
(usage == NV_OF_BUFFER_USAGE_COST) ||
(usage == NV_OF_BUFFER_USAGE_HINT))
{
bufferType = m_eOutBufType;
}

return bufferType;
}

std::vector<NvOFBufferObj>
NvOFCuda::DoAllocBuffers(NV_OF_BUFFER_DESCRIPTOR ofBufferDesc,
uint32_t elementSize, uint32_t numBuffers)
{
std::vector<NvOFBufferObj> ofBuffers;
for (uint32_t i = 0; i < numBuffers; ++i)
{
NV_OF_CUDA_BUFFER_TYPE bufferType = GetBufferType(ofBufferDesc.bufferUsage);
ofBuffers.emplace_back(CreateOFBufferObject(ofBufferDesc, elementSize, bufferType).release());
}
return ofBuffers;
}

std::unique_ptr<NvOFBuffer> NvOFCuda::CreateOFBufferObject(const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize, NV_OF_CUDA_BUFFER_TYPE bufferType)
{
std::unique_ptr<NvOFBuffer> pBuffer;
if (bufferType == NV_OF_CUDA_BUFFER_TYPE_CUARRAY)
{
pBuffer.reset(new NvOFBufferCudaArray(m_NvOFAPI, desc, elementSize));
}
else
{
pBuffer.reset(new NvOFBufferCudaDevicePtr(m_NvOFAPI, desc, elementSize));
}
return pBuffer;
}

NvOFBufferCudaDevicePtr::NvOFBufferCudaDevicePtr(std::shared_ptr<NvOFCudaAPI> ofAPI, const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize) :
NvOFBuffer(desc, elementSize), m_devPtr(0), m_NvOFAPI(ofAPI)
{
m_cuContext = m_NvOFAPI->GetCudaContext();
NVOF_API_CALL(m_NvOFAPI->GetAPI()->nvOFCreateGPUBufferCuda(m_NvOFAPI->GetHandle(),
&desc,
NV_OF_CUDA_BUFFER_TYPE_CUDEVICEPTR,
&m_hGPUBuffer));
m_devPtr = m_NvOFAPI->GetAPI()->nvOFGPUBufferGetCUdeviceptr(m_hGPUBuffer);
NVOF_API_CALL(m_NvOFAPI->GetAPI()->nvOFGPUBufferGetStrideInfo(m_hGPUBuffer, &m_strideInfo));
}

NvOFBufferCudaDevicePtr::~NvOFBufferCudaDevicePtr()
{
m_NvOFAPI->GetAPI()->nvOFDestroyGPUBufferCuda(m_hGPUBuffer);
}

void NvOFBufferCudaDevicePtr::UploadData(const void* pData,
CUmemorytype mem_type) {
CUstream stream = m_NvOFAPI->GetCudaStream(getBufferUsage());
CUDA_DRVAPI_CALL(cuCtxPushCurrent(m_cuContext));
CUDA_MEMCPY2D cuCopy2d;
memset(&cuCopy2d, 0, sizeof(cuCopy2d));
cuCopy2d.WidthInBytes = getWidth()* getElementSize();
mgb_assert(
CU_MEMORYTYPE_HOST == mem_type || CU_MEMORYTYPE_DEVICE == mem_type,
"do not imp mem type!!!");
cuCopy2d.srcMemoryType = mem_type;
if (CU_MEMORYTYPE_HOST == mem_type) {
cuCopy2d.srcHost = pData;
} else if (CU_MEMORYTYPE_DEVICE == mem_type) {
cuCopy2d.srcDevice = (CUdeviceptr)pData;
}
cuCopy2d.srcPitch = cuCopy2d.WidthInBytes;
cuCopy2d.dstMemoryType = CU_MEMORYTYPE_DEVICE;
cuCopy2d.dstDevice = getCudaDevicePtr();
cuCopy2d.dstPitch = m_strideInfo.strideInfo[0].strideXInBytes;
cuCopy2d.Height = getHeight();
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));

if (getBufferFormat() == NV_OF_BUFFER_FORMAT_NV12)
{
cuCopy2d.Height = (getHeight() + 1)/2;
cuCopy2d.srcHost = ((const uint8_t *)pData + (cuCopy2d.srcPitch * cuCopy2d.Height));
cuCopy2d.dstY = m_strideInfo.strideInfo[0].strideYInBytes;
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));
}
CUDA_DRVAPI_CALL(cuCtxPopCurrent(&m_cuContext));
}

void NvOFBufferCudaDevicePtr::DownloadData(void* pData, CUmemorytype mem_type) {
CUstream stream = m_NvOFAPI->GetCudaStream(getBufferUsage());
CUDA_DRVAPI_CALL(cuCtxPushCurrent(m_cuContext));
CUDA_MEMCPY2D cuCopy2d;
memset(&cuCopy2d, 0, sizeof(cuCopy2d));
cuCopy2d.WidthInBytes = getWidth() * getElementSize();

mgb_assert(
CU_MEMORYTYPE_HOST == mem_type || CU_MEMORYTYPE_DEVICE == mem_type,
"do not imp mem type!!!");
cuCopy2d.dstMemoryType = mem_type;
if (CU_MEMORYTYPE_HOST == mem_type) {
cuCopy2d.dstHost = pData;
} else if (CU_MEMORYTYPE_DEVICE == mem_type) {
cuCopy2d.dstDevice = (CUdeviceptr)pData;
}
cuCopy2d.dstPitch = cuCopy2d.WidthInBytes;
cuCopy2d.srcMemoryType = CU_MEMORYTYPE_DEVICE;
cuCopy2d.srcDevice = getCudaDevicePtr();
cuCopy2d.srcPitch = m_strideInfo.strideInfo[0].strideXInBytes;
cuCopy2d.Height = getHeight();
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));
if (getBufferFormat() == NV_OF_BUFFER_FORMAT_NV12)
{
cuCopy2d.Height = (getHeight() + 1) / 2;
cuCopy2d.dstHost = ((uint8_t *)pData + (cuCopy2d.dstPitch * cuCopy2d.Height));
cuCopy2d.srcY = m_strideInfo.strideInfo[0].strideYInBytes;
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));
}
CUDA_DRVAPI_CALL(cuStreamSynchronize(stream));
CUDA_DRVAPI_CALL(cuCtxPopCurrent(&m_cuContext));
}

NvOFBufferCudaArray::NvOFBufferCudaArray(std::shared_ptr<NvOFCudaAPI> ofAPI, const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize) :
NvOFBuffer(desc, elementSize), m_cuArray(0), m_NvOFAPI(ofAPI)
{
m_cuContext = m_NvOFAPI->GetCudaContext();
NVOF_API_CALL(m_NvOFAPI->GetAPI()->nvOFCreateGPUBufferCuda(m_NvOFAPI->GetHandle(),
&desc,
NV_OF_CUDA_BUFFER_TYPE_CUARRAY,
&m_hGPUBuffer));
m_cuArray = m_NvOFAPI->GetAPI()->nvOFGPUBufferGetCUarray(m_hGPUBuffer);
NVOF_API_CALL(m_NvOFAPI->GetAPI()->nvOFGPUBufferGetStrideInfo(m_hGPUBuffer, &m_strideInfo));
}

NvOFBufferCudaArray::~NvOFBufferCudaArray()
{
m_NvOFAPI->GetAPI()->nvOFDestroyGPUBufferCuda(m_hGPUBuffer);
}

void NvOFBufferCudaArray::UploadData(const void* pData, CUmemorytype mem_type) {
CUstream stream = m_NvOFAPI->GetCudaStream(getBufferUsage());
CUDA_DRVAPI_CALL(cuCtxPushCurrent(m_cuContext));
CUDA_MEMCPY2D cuCopy2d;
memset(&cuCopy2d, 0, sizeof(cuCopy2d));
cuCopy2d.WidthInBytes = getWidth() * getElementSize();
mgb_assert(
CU_MEMORYTYPE_HOST == mem_type || CU_MEMORYTYPE_DEVICE == mem_type,
"do not imp mem type!!!");
cuCopy2d.srcMemoryType = mem_type;
if (CU_MEMORYTYPE_HOST == mem_type) {
cuCopy2d.srcHost = pData;
} else if (CU_MEMORYTYPE_DEVICE == mem_type) {
cuCopy2d.srcDevice = (CUdeviceptr)pData;
}
cuCopy2d.srcPitch = cuCopy2d.WidthInBytes;
cuCopy2d.dstMemoryType = CU_MEMORYTYPE_ARRAY;
cuCopy2d.dstArray= getCudaArray();
cuCopy2d.Height = getHeight();
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));

if (getBufferFormat() == NV_OF_BUFFER_FORMAT_NV12)
{
cuCopy2d.Height = (getHeight() + 1) / 2;
cuCopy2d.srcHost = ((const uint8_t *)pData + (cuCopy2d.srcPitch * cuCopy2d.Height));
cuCopy2d.dstY = m_strideInfo.strideInfo[0].strideYInBytes;
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));
}
CUDA_DRVAPI_CALL(cuCtxPopCurrent(&m_cuContext));
}

void NvOFBufferCudaArray::DownloadData(void* pData, CUmemorytype mem_type) {
CUstream stream = m_NvOFAPI->GetCudaStream(getBufferUsage());
CUDA_DRVAPI_CALL(cuCtxPushCurrent(m_cuContext));
CUDA_MEMCPY2D cuCopy2d;
memset(&cuCopy2d, 0, sizeof(cuCopy2d));
cuCopy2d.WidthInBytes = getWidth() * getElementSize();

mgb_assert(
CU_MEMORYTYPE_HOST == mem_type || CU_MEMORYTYPE_DEVICE == mem_type,
"do not imp mem type!!!");
cuCopy2d.dstMemoryType = mem_type;
if (CU_MEMORYTYPE_HOST == mem_type) {
cuCopy2d.dstHost = pData;
} else if (CU_MEMORYTYPE_DEVICE == mem_type) {
cuCopy2d.dstDevice = (CUdeviceptr)pData;
}
cuCopy2d.dstPitch = cuCopy2d.WidthInBytes;
cuCopy2d.srcMemoryType = CU_MEMORYTYPE_ARRAY;
cuCopy2d.srcArray = getCudaArray();
cuCopy2d.Height = getHeight();
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));
if (getBufferFormat() == NV_OF_BUFFER_FORMAT_NV12)
{
cuCopy2d.Height = (getHeight() + 1) / 2;
cuCopy2d.dstHost = ((uint8_t *)pData + (cuCopy2d.dstPitch * cuCopy2d.Height));
cuCopy2d.srcY = m_strideInfo.strideInfo[0].strideYInBytes;
CUDA_DRVAPI_CALL(cuMemcpy2DAsync(&cuCopy2d, stream));
}
CUDA_DRVAPI_CALL(cuStreamSynchronize(stream));
CUDA_DRVAPI_CALL(cuCtxPopCurrent(&m_cuContext));
}

#endif

src/opr/impl/nvof/NvOFCuda.h (+178, -0)

@@ -0,0 +1,178 @@
/*
* Copyright 2018-2019 NVIDIA Corporation. All rights reserved.
*
* Please refer to the NVIDIA end user license agreement (EULA) associated
* with this source code for terms and conditions that govern your use of
* this software. Any use, reproduction, disclosure, or distribution of
* this software and related documentation outside the terms of the EULA
* is strictly prohibited.
*
*/

/**
* \file src/opr/impl/nvof/NvOFCuda.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#pragma once
#include <memory>
#include "cuda.h"
#include "nvOpticalFlowCommon.h"
#include "nvOpticalFlowCuda.h"
#include "NvOF.h"

#define CUDA_DRVAPI_CALL(call) \
do { \
CUresult err__ = call; \
if (err__ != CUDA_SUCCESS) { \
const char* szErrName = NULL; \
cuGetErrorName(err__, &szErrName); \
std::ostringstream errorLog; \
errorLog << "CUDA driver API error " << szErrName; \
std::cout << "Exception: " << __FILE__ << ":" << __LINE__ << ":" \
<< errorLog.str() << std::endl; \
mgb_throw(MegBrainError, "CUDA_DRVAPI_CALL ERROR"); \
} \
} while (0)

class NvOFCudaAPI : public NvOFAPI {
public:
NvOFCudaAPI(CUcontext cuContext, CUstream inputStream = nullptr, CUstream outputStream = nullptr);
~NvOFCudaAPI();

NV_OF_CUDA_API_FUNCTION_LIST* GetAPI()
{
std::lock_guard<std::mutex> lock(m_lock);
return m_ofAPI.get();
}

CUcontext GetCudaContext() { return m_cuContext; }
NvOFHandle GetHandle() { return m_hOF; }
CUstream GetCudaStream(NV_OF_BUFFER_USAGE usage);
private:
CUstream m_inputStream;
CUstream m_outputStream;
NvOFHandle m_hOF;
std::unique_ptr<NV_OF_CUDA_API_FUNCTION_LIST> m_ofAPI;
CUcontext m_cuContext;
};

/**
* @brief Optical Flow for the CUDA interface
*/
class NvOFCuda : public NvOF
{
public:
static NvOFObj Create(CUcontext cuContext, uint32_t nWidth, uint32_t nHeight,
NV_OF_BUFFER_FORMAT eInBufFmt,
NV_OF_CUDA_BUFFER_TYPE eInBufType,
NV_OF_CUDA_BUFFER_TYPE eOutBufType,
NV_OF_MODE eMode,
NV_OF_PERF_LEVEL preset,
CUstream inputStream = nullptr,
CUstream outputStream = nullptr);
~NvOFCuda() {};

private:
NvOFCuda(CUcontext cuContext,
uint32_t nWidth,
uint32_t nHeight,
NV_OF_BUFFER_FORMAT eInBufFmt,
NV_OF_CUDA_BUFFER_TYPE eInBufType,
NV_OF_CUDA_BUFFER_TYPE eOutBufType,
NV_OF_MODE eMode,
NV_OF_PERF_LEVEL preset,
CUstream inputStream = nullptr,
CUstream outputStream = nullptr);
/**
* @brief This function is used to retrieve supported grid size for output.
* This function is an override of pure virtual function NvOF::DoGetOutputGridSizes().
*/
virtual void DoGetOutputGridSizes(uint32_t* vals, uint32_t* size) override;

/**
* @brief This function is used to initialize the OF engine.
* This function is an override of pure virtual function NvOF::DoInit().
*/
virtual void DoInit(const NV_OF_INIT_PARAMS& initParams) override;

/**
* @brief This function is used to estimate the optical flow between 2 images.
* This function is an override of pure virtual function NvOF::DoExecute().
*/
virtual void DoExecute(const NV_OF_EXECUTE_INPUT_PARAMS& executeInParams, NV_OF_EXECUTE_OUTPUT_PARAMS& executeOutParams) override;

/**
* @brief This function is used to allocate buffers used for optical flow estimation
* using the cuda interface. This function is an override of pure virtual function
* NvOF::DoAllocBuffers().
*/
virtual std::vector<NvOFBufferObj> DoAllocBuffers(NV_OF_BUFFER_DESCRIPTOR ofBufferDesc,
uint32_t elementSize, uint32_t numBuffers) override;

/**
* @brief This a helper function for allocating NvOFBuffer objects using the cuda
* interface.
*/
std::unique_ptr<NvOFBuffer> CreateOFBufferObject(const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize, NV_OF_CUDA_BUFFER_TYPE bufferType);
NV_OF_CUDA_BUFFER_TYPE GetBufferType(NV_OF_BUFFER_USAGE usage);

private:
CUcontext m_cuContext;
std::shared_ptr<NvOFCudaAPI> m_NvOFAPI;
NV_OF_CUDA_BUFFER_TYPE m_eInBufType;
NV_OF_CUDA_BUFFER_TYPE m_eOutBufType;
};

/*
* A wrapper over an NvOFGPUBufferHandle which has been created with buffer
* type NV_OF_CUDA_BUFFER_TYPE_CUDEVICEPTR.
*/
class NvOFBufferCudaDevicePtr : public NvOFBuffer
{
public:
~NvOFBufferCudaDevicePtr();
CUdeviceptr getCudaDevicePtr() { return m_devPtr; }
virtual void UploadData(const void* pData, CUmemorytype mem_type) override;
virtual void DownloadData(void* pData, CUmemorytype mem_type) override;
NV_OF_CUDA_BUFFER_STRIDE_INFO getStrideInfo() { return m_strideInfo; }
private:
NvOFBufferCudaDevicePtr(std::shared_ptr<NvOFCudaAPI> ofAPI, const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize);
CUdeviceptr m_devPtr;
CUcontext m_cuContext;
NV_OF_CUDA_BUFFER_STRIDE_INFO m_strideInfo;
std::shared_ptr<NvOFCudaAPI> m_NvOFAPI;
friend class NvOFCuda;
};

/*
* A wrapper over an NvOFGPUBufferHandle which has been created with buffer
* type NV_OF_CUDA_BUFFER_TYPE_CUARRAY.
*/
class NvOFBufferCudaArray : public NvOFBuffer
{
public:
~NvOFBufferCudaArray();
virtual void UploadData(const void* pData, CUmemorytype mem_type) override;
virtual void DownloadData(void* pData, CUmemorytype mem_type) override;
CUarray getCudaArray() { return m_cuArray; }
private:
NvOFBufferCudaArray(std::shared_ptr<NvOFCudaAPI> ofAPI, const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize);
CUarray m_cuArray;
CUcontext m_cuContext;
NV_OF_CUDA_BUFFER_STRIDE_INFO m_strideInfo;
std::shared_ptr<NvOFCudaAPI> m_NvOFAPI;
friend class NvOFCuda;
};

#endif

src/opr/impl/nvof/NvOFDefines.h (+55, -0)

@@ -0,0 +1,55 @@
/*
* Copyright 2018 NVIDIA Corporation. All rights reserved.
*
* Please refer to the NVIDIA end user license agreement (EULA) associated
* with this source code for terms and conditions that govern your use of
* this software. Any use, reproduction, disclosure, or distribution of
* this software and related documentation outside the terms of the EULA
* is strictly prohibited.
*
*/

/**
* \file src/opr/impl/nvof/NvOFDefines.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#pragma once
#ifdef _WIN32
#define NOMINMAX
#include <Windows.h>
//FIXME: mgb code redefines CALLBACK; some Win32 APIs will be disabled
#undef CALLBACK
#undef CONST
#define DIR_SEP "\\"
#else
#define HMODULE void *
#define _stricmp strcasecmp
#define DIR_SEP "/"
#endif
#include <memory>

class NvOF;
class NvOFBuffer;

/**
* @brief A managed pointer wrapper for NvOF class objects
*/
using NvOFObj = std::unique_ptr<NvOF>;

/**
* @brief A managed pointer wrapper for NvOFBuffer class objects
*/
using NvOFBufferObj = std::unique_ptr<NvOFBuffer>;

#endif

src/opr/impl/nvof/denseflownvidia.cpp (+202, -0)

@@ -0,0 +1,202 @@
/**
* \file src/opr/impl/nvof/denseflownvidia.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#include <mutex>
#include <vector>
#include "megbrain/common.h"
#include "denseflownvidia.h"

NVFlowExtractor::NVFlowExtractor(int device_id, std::vector<size_t>& shape,
uint32_t preset, bool use_cuda_stream,
bool debug) {
batch_size = shape[0];
m_width = shape[3];
m_height = shape[2];
debug_flag = debug;
m_temporal_size = shape[1];
m_use_cuda_stream = use_cuda_stream;
out_width = (m_width + m_out_grid_size - 1) / m_out_grid_size;
out_height = (m_height + m_out_grid_size - 1) / m_out_grid_size;
m_width_in_blocks = (m_width + m_blockSizeX - 1) / m_blockSizeX;
m_height_in_blocks = (m_height + m_blockSizeY - 1) / m_blockSizeY;
out_size = out_width * out_height * 2;
m_device_id = device_id;

std::unordered_map<uint32_t, NV_OF_PERF_LEVEL> preset_map = {
{0, NV_OF_PERF_LEVEL_SLOW},
{1, NV_OF_PERF_LEVEL_MEDIUM},
{2, NV_OF_PERF_LEVEL_FAST}};

_preset = preset;
auto search = preset_map.find(_preset);
if (search == preset_map.end()) {
mgb_throw(MegBrainError, "NVOF: invalid preset level! err type: NV_OF_ERR_INVALID_PARAM");
}
perf_preset = search->second;
}

void NVFlowExtractor::create_nvof_instances(int height, int width) {
nv_optical_flow = NvOFCuda::Create(cu_context, width, height, buffer_format,
input_buffer_type, output_buffer_type,
NV_OF_MODE_OPTICALFLOW, perf_preset,
input_stream, output_stream);
nv_optical_flow->Init(m_out_grid_size);
input_buffers = nv_optical_flow->CreateBuffers(
NV_OF_BUFFER_USAGE_INPUT, buffer_pool_size * batch_size);
output_buffers = nv_optical_flow->CreateBuffers(
NV_OF_BUFFER_USAGE_OUTPUT, (buffer_pool_size - 1) * batch_size);
}

void NVFlowExtractor::init_nvof_engine() {
std::lock_guard<std::mutex> lock(m_lock);
if (init_flag == false) {
set_device(m_device_id);
if (cuCtxCreate(&cu_context, 0, cu_device)) {
mgb_log_warn(
"nvof: create ctx failed, fallback to get current ctx");
CUDA_DRVAPI_CALL(cuCtxGetCurrent(&cu_context));
}

if (m_use_cuda_stream) {
CUDA_DRVAPI_CALL(cuStreamCreate(&input_stream, CU_STREAM_DEFAULT));
CUDA_DRVAPI_CALL(cuStreamCreate(&output_stream, CU_STREAM_DEFAULT));
}
create_nvof_instances(m_height, m_width);
init_flag = true;
}
}

NVFlowExtractor::~NVFlowExtractor() {
if (debug_flag) {
mgb_log_debug("%s: %d start", __FUNCTION__, __LINE__);
}

if (m_use_cuda_stream) {
cuStreamDestroy(output_stream);
output_stream = nullptr;
cuStreamDestroy(input_stream);
input_stream = nullptr;
}

if (debug_flag) {
mgb_log_debug("%s: %d end", __FUNCTION__, __LINE__);
}
}

void NVFlowExtractor::set_device(int dev_id) {
int nGpu = 0;

if (debug_flag) {
mgb_log_warn("config nvof gpu device id: %d", dev_id);
}

CUDA_DRVAPI_CALL(cuInit(0));
CUDA_DRVAPI_CALL(cuDeviceGetCount(&nGpu));
if (dev_id < 0 || dev_id >= nGpu) {
mgb_log_warn("GPU ordinal out of range. Should be with in [0, %d]",
nGpu - 1);
mgb_throw(MegBrainError, "NVOF: GPU Setting Error! err type: NV_OF_ERR_GENERIC");
}
CUDA_DRVAPI_CALL(cuDeviceGet(&cu_device, dev_id));
}

CUmemorytype NVFlowExtractor::get_mem_type(CUdeviceptr p) {
unsigned int mem_type;
auto ret = cuPointerGetAttribute(&mem_type,
CU_POINTER_ATTRIBUTE_MEMORY_TYPE, p);

if (CUDA_SUCCESS == ret) {
mgb_assert(
CU_MEMORYTYPE_DEVICE == mem_type ||
CU_MEMORYTYPE_HOST == mem_type,
"only imp CU_MEMORYTYPE_HOST or CU_MEMORYTYPE_DEVICE mem type");
} else {
mgb_log_warn(
"nvof call cuPointerGetAttribute err!!, may init nvof opr on "
"cpu comp_node, force set mem type to CU_MEMORYTYPE_HOST");
mem_type = CU_MEMORYTYPE_HOST;
}

return static_cast<CUmemorytype_enum>(mem_type);
}

void NVFlowExtractor::extract_flow(unsigned char* frames,
std::vector<size_t>& shape,
int16_t* result_out_ptr) {
auto batch_size = shape[0];
auto temporal_size = shape[1];
auto height = shape[2];
auto width = shape[3];
auto channel = shape[4];
auto temporal_len = height * width * channel;
auto batch_len = temporal_size * height * width * channel;

init_nvof_engine();

auto src_mem_type = get_mem_type(reinterpret_cast<CUdeviceptr>(frames));
auto out_mem_type =
get_mem_type(reinterpret_cast<CUdeviceptr>(result_out_ptr));

if ((height != m_height || width != m_width) ||
(m_temporal_size != temporal_size)) {
mgb_log_warn("We do not support dynamic shape at mgb side");
mgb_throw(MegBrainError, "NVOF: Nvof err shap!!!! err type: NV_OF_ERR_GENERIC");
}

for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
auto input_buffer_batch_offset = buffer_pool_size * batch_idx;
auto output_buffer_batch_offset = (buffer_pool_size - 1) * batch_idx;
input_buffers[input_buffer_batch_offset]->UploadData(
(unsigned char*)(frames + batch_idx * batch_len), src_mem_type);

for (size_t temporal_idx = 1; temporal_idx < temporal_size;
temporal_idx++) {
input_buffers[input_buffer_batch_offset +
temporal_idx % buffer_pool_size]
->UploadData(
(unsigned char*)(frames + batch_idx * batch_len +
temporal_idx * temporal_len),
src_mem_type);

nv_optical_flow->Execute(
input_buffers[input_buffer_batch_offset +
(temporal_idx - 1) % buffer_pool_size]
.get(),
input_buffers[input_buffer_batch_offset +
temporal_idx % buffer_pool_size]
.get(),
output_buffers[output_buffer_batch_offset +
(temporal_idx - 1) % (buffer_pool_size - 1)]
.get(),
nullptr, nullptr);

output_buffers[output_buffer_batch_offset +
(temporal_idx - 1) % (buffer_pool_size - 1)]
->DownloadData(
result_out_ptr +
batch_idx * (temporal_size - 1) * out_size +
(temporal_idx - 1) * out_size,
out_mem_type);
}
}

CUDA_DRVAPI_CALL(cuCtxSynchronize());
}
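The modular indexing above cycles each batch's frames through a fixed pool of device buffers (buffer_pool_size = 6 input slots and buffer_pool_size - 1 = 5 output slots per batch, matching the allocations in create_nvof_instances). A sketch of the slot assignment for consecutive frame pairs, assuming that pool size:

def buffer_slots(temporal_size, pool=6):
    # For each frame pair (t-1, t): input slot holding frame t-1,
    # input slot holding frame t, and the output slot for their flow.
    for t in range(1, temporal_size):
        yield (t - 1) % pool, t % pool, (t - 1) % (pool - 1)

# 8 frames -> 7 flow maps; input slots wrap mod 6, output slots mod 5.
print(list(buffer_slots(8)))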

float NVFlowExtractor::get_precision() {
return m_precision;
}

#endif

src/opr/impl/nvof/denseflownvidia.h (+80, -0)

@@ -0,0 +1,80 @@
/**
* \file src/opr/impl/nvof/denseflownvidia.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#pragma once
#include <cuda.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "NvOFCuda.h"

class NVFlowExtractor {
public:
NVFlowExtractor(int device_id, std::vector<size_t>& shape,
uint32_t preset, bool use_cuda_stream, bool debug);
void create_nvof_instances(int height, int width);
~NVFlowExtractor();
void set_device(int dev_id);
void init_memory(int batch_size, int temporal_size);
void extract_flow(unsigned char* frames, std::vector<size_t>&, int16_t*);
CUmemorytype get_mem_type(CUdeviceptr);
float get_precision();
void init_nvof_engine();

private:
int buffer_pool_size = 6;
bool debug_flag = false;
bool m_use_cuda_stream = false;
bool init_flag = false;
size_t m_device_id = 0;
float m_precision = 32.0f;
uint32_t _preset = 1;
size_t batch_size = 0;
size_t out_size = 0;
size_t m_width = 0;
size_t m_height = 0;
size_t m_temporal_size = 0;
size_t out_width = 0;
size_t out_height = 0;
size_t m_width_in_blocks = 0;
size_t m_height_in_blocks = 0;
size_t m_blockSizeX = 4;
size_t m_blockSizeY = 4;

NV_OF_PERF_LEVEL perf_preset = NV_OF_PERF_LEVEL_MEDIUM;
NV_OF_BUFFER_FORMAT buffer_format = NV_OF_BUFFER_FORMAT_ABGR8;
NV_OF_CUDA_BUFFER_TYPE input_buffer_type =
NV_OF_CUDA_BUFFER_TYPE_CUDEVICEPTR;
NV_OF_CUDA_BUFFER_TYPE output_buffer_type =
NV_OF_CUDA_BUFFER_TYPE_CUDEVICEPTR;
NV_OF_OUTPUT_VECTOR_GRID_SIZE m_out_grid_size =
NV_OF_OUTPUT_VECTOR_GRID_SIZE_4;

NvOFObj nv_optical_flow;
CUdevice cu_device = 0;
CUcontext cu_context = nullptr;
CUstream input_stream = nullptr;
CUstream output_stream = nullptr;
std::vector<NvOFBufferObj> input_buffers;
std::vector<NvOFBufferObj> output_buffers;

protected:
std::mutex m_lock;
};

#endif

src/opr/impl/nvof/nvOpticalFlowCommon.h (+510, -0)

@@ -0,0 +1,510 @@
/*
* This copyright notice applies to this header file only:
*
* Copyright (c) 2018 NVIDIA Corporation
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the software, and to permit persons to whom the
* software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
/**
* \file nvOpticalFlowCommon.h
* NVIDIA GPUs - Turing and above - contain a hardware-based optical flow engine
* which provides fully-accelerated optical flow and stereo estimation.
* nvOpticalFlowCommon.h provides enums, structure definitions and function
* prototypes which are common across different devices, and uses #pragma
* directives to pack structure members with one byte alignment.
* \date 2018
*/

/**
* \file src/opr/impl/nvof/nvOpticalFlowCommon.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#ifndef _NV_OPTICALFLOW_COMMON_H_
#define _NV_OPTICALFLOW_COMMON_H_
#if defined(_MSC_VER_) && (_MSC_VER_ < 1600)
#ifndef _STDINT
typedef __int32 int32_t;
typedef unsigned __int32 uint32_t;
typedef __int64 int64_t;
typedef unsigned __int64 uint64_t;
typedef signed char int8_t;
typedef unsigned char uint8_t;
typedef short int16_t;
typedef unsigned short uint16_t;
#endif
#else
#include <stdint.h>
#endif

#ifdef _WIN32
#define NVOFAPI __stdcall
#else
#define NVOFAPI
#endif
#define NV_OF_API_MAJOR_VERSION 1
#define NV_OF_API_MINOR_VERSION 1
#define NV_OF_API_VERSION (uint16_t)((NV_OF_API_MAJOR_VERSION << 4) | NV_OF_API_MINOR_VERSION)
#define MIN_ERROR_STRING_SIZE 80
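NV_OF_API_VERSION packs the major and minor numbers into a single value via (major << 4) | minor; spelled out with the values from the macros above:

NV_OF_API_MAJOR_VERSION = 1
NV_OF_API_MINOR_VERSION = 1
NV_OF_API_VERSION = (NV_OF_API_MAJOR_VERSION << 4) | NV_OF_API_MINOR_VERSION
assert NV_OF_API_VERSION == 0x11  # decimal 17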

#if defined(__cplusplus)
extern "C"
{
#endif /* __cplusplus */

typedef struct NvOFHandle_st *NvOFHandle;
typedef struct NvOFGPUBufferHandle_st *NvOFGPUBufferHandle;
typedef struct NVOFPrivDataHandle_st *NvOFPrivDataHandle;

/**
* Supported error codes
*/
typedef enum _NV_OF_STATUS
{
/**
* This indicates that API call returned with no errors.
*/
NV_OF_SUCCESS,

/**
* This indicates that HW Optical flow functionality is not supported
*/
NV_OF_ERR_OF_NOT_AVAILABLE,

/**
* This indicates that device passed by the client is not supported.
*/
NV_OF_ERR_UNSUPPORTED_DEVICE,

/**
* This indicates that device passed to the API call is no longer available and
* needs to be reinitialized.
*/
NV_OF_ERR_DEVICE_DOES_NOT_EXIST,

/**
* This indicates that one or more of the pointers passed to the API call
* is invalid.
*/
NV_OF_ERR_INVALID_PTR,

/**
* This indicates that one or more of the parameter passed to the API call
* is invalid.
*/
NV_OF_ERR_INVALID_PARAM,

/**
* This indicates that an API call was made in wrong sequence/order.
*/
NV_OF_ERR_INVALID_CALL,

/**
* This indicates that an invalid struct version was used by the client.
*/
NV_OF_ERR_INVALID_VERSION,

/**
* This indicates that the API call failed because it was unable to allocate
* enough memory to perform the requested operation.
*/
NV_OF_ERR_OUT_OF_MEMORY,

/**
* This indicates that the OF session has not been initialized with
* ::NvOFInit() or that initialization has failed.
*/
NV_OF_ERR_NOT_INITIALIZED,

/**
* This indicates that an unsupported parameter was passed by the client.
*/
NV_OF_ERR_UNSUPPORTED_FEATURE,

/**
* This indicates that an unknown internal error has occurred.
*/
NV_OF_ERR_GENERIC,
} NV_OF_STATUS;

/**
* Supported bool values
*/
typedef enum _NV_OF_BOOL
{
NV_OF_FALSE = 0, /**< Represents false bool value */
NV_OF_TRUE = !NV_OF_FALSE /**< Represents true bool value */
} NV_OF_BOOL;

/**
* Supported optical flow and stereo disparity capability values.
*/
typedef enum _NV_OF_CAPS
{
NV_OF_CAPS_SUPPORTED_OUTPUT_GRID_SIZES, /**< Indicates supported values of ::NV_OF_OUTPUT_VECTOR_GRID_SIZE,
::NV_OF_INIT_PARAMS::outGridSize should be set with a supported output gridsize. */
NV_OF_CAPS_SUPPORTED_HINT_GRID_SIZES, /**< Indicates supported values of ::NV_OF_HINT_VECTOR_GRID_SIZE,
::NV_OF_INIT_PARAMS::hintGridSize should be set with a supported hint gridsize. */
NV_OF_CAPS_SUPPORT_HINT_WITH_OF_MODE, /**< Indicates external hint support for ::NV_OF_MODE_OPTICALFLOW mode.
0: External hint not supported for ::NV_OF_MODE_OPTICALFLOW mode.
1: External hint is supported for ::NV_OF_MODE_OPTICALFLOW mode. */
NV_OF_CAPS_SUPPORT_HINT_WITH_ST_MODE /**< Indicates external hint support for ::NV_OF_MODE_STEREODISPARITY mode.
0: External hint not supported for ::NV_OF_MODE_STEREODISPARITY mode.
1: External hint is supported for ::NV_OF_MODE_STEREODISPARITY mode. */
} NV_OF_CAPS;

/**
* Supported optical flow/stereo disparity performance levels.
*/
typedef enum _NV_OF_PERF_LEVEL
{
NV_OF_PERF_LEVEL_UNDEFINED,
NV_OF_PERF_LEVEL_SLOW = 5, /**< Slow perf level results in lowest performance and best quality */
NV_OF_PERF_LEVEL_MEDIUM = 10, /**< Medium perf level results in low performance and medium quality */
NV_OF_PERF_LEVEL_FAST = 20, /**< Fast perf level results in high performance and low quality */
NV_OF_PERF_LEVEL_MAX
} NV_OF_PERF_LEVEL;
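Note these enum values (5/10/20) are not what the Python precision argument takes; NVFlowExtractor (denseflownvidia.cpp above) maps 0/1/2 onto the presets:

# Mapping used by NVFlowExtractor's preset_map (see denseflownvidia.cpp):
PRESET_MAP = {
    0: "NV_OF_PERF_LEVEL_SLOW",
    1: "NV_OF_PERF_LEVEL_MEDIUM",
    2: "NV_OF_PERF_LEVEL_FAST",
}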

/**
* Supported grid size for output buffer ::NV_OF_EXECUTE_PARAMS::outputBuffer.
* Client should set ::NV_OF_INIT_PARAMS::outGridSize with ::NV_OF_OUTPUT_VECTOR_GRID_SIZE values.
*/
typedef enum _NV_OF_OUTPUT_VECTOR_GRID_SIZE
{
NV_OF_OUTPUT_VECTOR_GRID_SIZE_UNDEFINED,
NV_OF_OUTPUT_VECTOR_GRID_SIZE_4 = 4, /**< Output buffer grid size is 4x4 */
NV_OF_OUTPUT_VECTOR_GRID_SIZE_MAX
} NV_OF_OUTPUT_VECTOR_GRID_SIZE;

/**
* Expected grid size for optional parameter ::NV_OF_EXECUTE_PARAMS::externalHints buffer.
* Client should set ::NV_OF_INIT_PARAMS::hintGridSize with ::NV_OF_HINT_VECTOR_GRID_SIZE values.
*/
typedef enum _NV_OF_HINT_VECTOR_GRID_SIZE
{
NV_OF_HINT_VECTOR_GRID_SIZE_UNDEFINED,
NV_OF_HINT_VECTOR_GRID_SIZE_4 = 4, /**< Hint buffer grid size is 4x4.*/
NV_OF_HINT_VECTOR_GRID_SIZE_8 = 8, /**< Hint buffer grid size is 8x8.*/
NV_OF_HINT_VECTOR_GRID_SIZE_MAX
} NV_OF_HINT_VECTOR_GRID_SIZE;

/**
* ::NV_OF_MODE enum define values for Optical flow and Stereo disparity modes.
* Clients need to set ::NV_OF_INIT_PARAMS::mode with ::NV_OF_MODE values.
* For the ::NV_OF_MODE_OPTICALFLOW mode, the buffer format for ::NV_OF_EXECUTE_PARAMS::externalHints
* and ::NV_OF_EXECUTE_PARAMS::outputBuffer is ::NV_OF_FLOW_VECTOR.
* For the ::NV_OF_MODE_STEREODISPARITY mode, the buffer format for ::NV_OF_EXECUTE_PARAMS::externalHints
* and ::NV_OF_EXECUTE_PARAMS::outputBuffer is ::NV_OF_STEREO_DISPARITY.
*/
typedef enum _NV_OF_MODE
{
NV_OF_MODE_UNDEFINED,
NV_OF_MODE_OPTICALFLOW, /**< Calculate optical flow between two frames. */
NV_OF_MODE_STEREODISPARITY, /**< Calculate disparity between Stereo view pair. */
NV_OF_MODE_MAX
} NV_OF_MODE;

/**
* Supported buffer type for ::NvOFGPUBufferHandle allocation.
* Clients need to set NV_OF_CREATE_BUFFER::bufferUsage with ::NV_OF_BUFFER_USAGE enum values.
*/
typedef enum _NV_OF_BUFFER_USAGE
{
NV_OF_BUFFER_USAGE_UNDEFINED,
NV_OF_BUFFER_USAGE_INPUT, /**< Input buffer type is used to allocate ::NV_OF_INPUT_EXECUTE_PARAMS::inputFrame,
::NV_OF_INPUT_EXECUTE_PARAMS::referenceFrame. */
NV_OF_BUFFER_USAGE_OUTPUT, /**< Output buffer type is used to allocate ::NV_OF_OUTPUT_EXECUTE_PARAMS::outputBuffer. */
NV_OF_BUFFER_USAGE_HINT, /**< Hint buffer type is used to allocate ::NV_OF_INPUT_EXECUTE_PARAMS::externalHints.*/
NV_OF_BUFFER_USAGE_COST, /**< Cost buffer type is used to allocate ::NV_OF_OUTPUT_EXECUTE_PARAMS::outputCostBuffer.*/
NV_OF_BUFFER_USAGE_MAX
} NV_OF_BUFFER_USAGE;

/**
* Supported buffer formats
*/
typedef enum _NV_OF_BUFFER_FORMAT
{
NV_OF_BUFFER_FORMAT_UNDEFINED,
NV_OF_BUFFER_FORMAT_GRAYSCALE8, /**< Input buffer format with 8 bit planar format */
NV_OF_BUFFER_FORMAT_NV12, /**< Input buffer format with 8 bit planar, UV interleaved */
NV_OF_BUFFER_FORMAT_ABGR8, /**< Input buffer format with 8 bit packed A8B8G8R8 */
NV_OF_BUFFER_FORMAT_SHORT, /**< Output or hint buffer format for stereo disparity */
NV_OF_BUFFER_FORMAT_SHORT2, /**< Output or hint buffer format for optical flow vector */
NV_OF_BUFFER_FORMAT_UINT, /**< Cost buffer format for optical flow vector / stereo disparity */
NV_OF_BUFFER_FORMAT_MAX
} NV_OF_BUFFER_FORMAT;

/**
* \struct NV_OF_FLOW_VECTOR
* Struct needed for optical flow. ::NV_OF_EXECUTE_OUTPUT_PARAMS::outputBuffer will be populated with optical flow
* in ::NV_OF_FLOW_VECTOR format for each ::NV_OF_INIT_PARAMS::outGridSize.
* Flow vectors flowx and flowy are 16-bit values with the lowest 5 bits holding fractional value,
* followed by a 10-bit integer value and the most significant bit being a sign bit.
*/
typedef struct _NV_OF_FLOW_VECTOR
{
int16_t flowx; /**< x component of flow in S10.5 format */
int16_t flowy; /**< y component of flow in S10.5 format */
} NV_OF_FLOW_VECTOR;

/**
* \struct NV_OF_STEREO_DISPARITY
* Struct needed for stereo /disparity. ::NV_OF_OUTPUT_EXECUTE_PARAMS::outputBuffer will be populated
* with stereo disparity in ::NV_OF_STEREO_DISPARITY format for each ::NV_OF_INIT_PARAMS::outGridSize.
* Stereo disparity is a 16-bit value with the lowest 5 bits holding fractional value,
* followed by an 11-bit unsigned integer value.
*/
typedef struct _NV_OF_STEREO_DISPARITY
{
uint16_t disparity; /**< Horizontal displacement[in pixels] in 11.5 format. */
} NV_OF_STEREO_DISPARITY;

/**
* \struct NV_OF_INIT_PARAMS
* Optical flow/stereo disparity session initialization parameters.
*/
typedef struct _NV_OF_INIT_PARAMS
{
uint32_t width; /**< [in]: Specifies input buffer width */
uint32_t height; /**< [in]: Specifies input buffer height */
NV_OF_OUTPUT_VECTOR_GRID_SIZE outGridSize; /**< [in]: Specifies flow vector grid size for the ::NV_OF_EXECUTE_OUTPUT_PARAMS::outputBuffer buffer. */
NV_OF_HINT_VECTOR_GRID_SIZE hintGridSize; /**< [in]: Specifies flow vector grid size for the ::NV_OF_EXECUTE_INPUT_PARAMS::externalHints buffer.
This field is only considered if ::NV_OF_INIT_PARAMS::enableExternalHints is set. */
NV_OF_MODE mode; /**< [in]: Operating mode for NVOF. Set to a value defined by enum ::NV_OF_MODE. */
NV_OF_PERF_LEVEL perfLevel; /**< [in]: Specifies perf level. */
NV_OF_BOOL enableExternalHints; /**< [in]: Set to 1 to enable external hints for optical flow session. */
NV_OF_BOOL enableOutputCost; /**< [in]: Set to 1 to enable output cost calculation for optical flow session. */
NvOFPrivDataHandle hPrivData; /**< [in]: Optical flow private data. This is a reserved field and should be set to NULL. */
} NV_OF_INIT_PARAMS;
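
/*
 * Usage sketch (editor's example, not SDK sample code): initializing a plain
 * optical-flow session for 224x224 frames. NV_OF_OUTPUT_VECTOR_GRID_SIZE_4 is
 * an assumed enumerator name for the 4x4 output grid; check the
 * NV_OF_OUTPUT_VECTOR_GRID_SIZE enum in your copy of this header.
 */
static NV_OF_INIT_PARAMS nvofMakeInitParamsExample(void)
{
    NV_OF_INIT_PARAMS initParams = {0};
    initParams.width = 224;
    initParams.height = 224;
    initParams.outGridSize = NV_OF_OUTPUT_VECTOR_GRID_SIZE_4; /* assumed name */
    initParams.mode = NV_OF_MODE_OPTICALFLOW;
    initParams.perfLevel = NV_OF_PERF_LEVEL_MEDIUM;
    /* enableExternalHints/enableOutputCost stay 0; hPrivData stays NULL (reserved) */
    return initParams;
}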

/**
* \struct NV_OF_BUFFER_DESCRIPTOR
* Creation parameters for optical flow buffers.
*/
typedef struct _NV_OF_BUFFER_DESCRIPTOR
{
uint32_t width; /**< [in]: Buffer width. */
uint32_t height; /**< [in]: Buffer height. */
NV_OF_BUFFER_USAGE bufferUsage; /**< [in]: To specify buffer usage type.
::NV_OF_BUFFER_USAGE_OUTPUT buffer usage type accepts ::NV_OF_BUFFER_DESCRIPTOR::width,
::NV_OF_BUFFER_DESCRIPTOR::height in ::NV_OF_INIT_PARAMS::outGridSize units.
::NV_OF_BUFFER_USAGE_HINT buffer usage type accepts ::NV_OF_BUFFER_DESCRIPTOR::width,
::NV_OF_BUFFER_DESCRIPTOR::height in ::NV_OF_INIT_PARAMS::hintGridSize units. */
NV_OF_BUFFER_FORMAT bufferFormat; /**< [in]: Buffer format. */

} NV_OF_BUFFER_DESCRIPTOR;
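
/*
 * Usage sketch (editor's example): output buffers are described in outGridSize
 * units, so a 224x224 frame with a 4x4 output grid needs a 56x56 descriptor in
 * the NV_OF_BUFFER_FORMAT_SHORT2 flow-vector format.
 */
static NV_OF_BUFFER_DESCRIPTOR nvofMakeOutputDescExample(void)
{
    NV_OF_BUFFER_DESCRIPTOR desc = {0};
    desc.width = 224 / 4;   /* frame width  / outGridSize */
    desc.height = 224 / 4;  /* frame height / outGridSize */
    desc.bufferUsage = NV_OF_BUFFER_USAGE_OUTPUT;
    desc.bufferFormat = NV_OF_BUFFER_FORMAT_SHORT2;
    return desc;
}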

/**
* \struct NV_OF_EXECUTE_INPUT_PARAMS
* Parameters which are sent per frame for optical flow/stereo disparity execution.
*/
typedef struct _NV_OF_EXECUTE_INPUT_PARAMS
{
NvOFGPUBufferHandle inputFrame; /**< [in]: If ::NV_OF_INIT_PARAMS::mode is ::NV_OF_MODE_OPTICALFLOW, this specifies the handle to the buffer containing the input frame.
If ::NV_OF_INIT_PARAMS::mode is ::NV_OF_MODE_STEREODISPARITY, this specifies the handle to the buffer containing the rectified left view. */
NvOFGPUBufferHandle referenceFrame; /**< [in]: If ::NV_OF_INIT_PARAMS::mode is ::NV_OF_MODE_OPTICALFLOW, this specifies the handle to the buffer containing the reference frame.
If ::NV_OF_INIT_PARAMS::mode is ::NV_OF_MODE_STEREODISPARITY, this specifies the handle to the buffer containing the rectified right view. */
NvOFGPUBufferHandle externalHints; /**< [in]: This is an optional input; the field is considered only if the client has set the ::NV_OF_INIT_PARAMS::enableExternalHints flag.
The client can pass some available predictors as hints.
The optical flow driver will search around those hints to optimize flow vector quality.
The expected hint buffer format is ::NV_OF_FLOW_VECTOR or ::NV_OF_STEREO_DISPARITY
for the ::NV_OF_MODE_OPTICALFLOW and ::NV_OF_MODE_STEREODISPARITY modes respectively, for
each ::NV_OF_INIT_PARAMS::hintGridSize in a frame. */
NV_OF_BOOL disableTemporalHints; /**< [in]: To disable temporal hints per optical flow/stereo disparity execution.
Temporal hints are enabled by default.
The user can choose to disable temporal hints if there is no
dependency on the previous optical flow execution. */
uint32_t padding; /**< [in]: Padding. Must be set to 0. */
NvOFPrivDataHandle hPrivData; /**< [in]: Optical flow private data handle. This is a reserved field and should be set to NULL. */
} NV_OF_EXECUTE_INPUT_PARAMS;

/**
* \struct NV_OF_EXECUTE_OUTPUT_PARAMS
* Parameters which are received per frame for optical flow/stereo disparity execution.
*/
typedef struct _NV_OF_EXECUTE_OUTPUT_PARAMS
{
NvOFGPUBufferHandle outputBuffer; /**< [in]: Specifies the optical flow or stereo disparity buffer handle.
::outputBuffer will be populated with optical flow in
::NV_OF_FLOW_VECTOR format or stereo disparity in
::NV_OF_STEREO_DISPARITY format for each
::NV_OF_INIT_PARAMS::outGridSize in a frame. */
NvOFGPUBufferHandle outputCostBuffer; /**< [in]: Specifies the output cost calculation buffer handle. */
NvOFPrivDataHandle hPrivData; /**< [in]: Optical flow private data handle. This is a reserved field and should be set to NULL. */
} NV_OF_EXECUTE_OUTPUT_PARAMS;
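
/*
 * Usage sketch (editor's example, not part of the SDK): filling the per-frame
 * execute structs from previously created GPU buffer handles. The three handle
 * parameters are hypothetical; they stand for buffers created with the CUDA
 * buffer-creation API declared in nvOpticalFlowCuda.h.
 */
static void nvofFillExecuteParamsExample(NvOFGPUBufferHandle hInput,
                                         NvOFGPUBufferHandle hReference,
                                         NvOFGPUBufferHandle hOutput,
                                         NV_OF_EXECUTE_INPUT_PARAMS* inParams,
                                         NV_OF_EXECUTE_OUTPUT_PARAMS* outParams)
{
    NV_OF_EXECUTE_INPUT_PARAMS in = {0};   /* zero: no hints, temporal hints on */
    NV_OF_EXECUTE_OUTPUT_PARAMS out = {0};
    in.inputFrame = hInput;
    in.referenceFrame = hReference;
    out.outputBuffer = hOutput;            /* receives NV_OF_FLOW_VECTOR data */
    /* externalHints/outputCostBuffer stay NULL: only read when the matching
       NV_OF_INIT_PARAMS enable flags were set; hPrivData stays NULL (reserved) */
    *inParams = in;
    *outParams = out;
}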

/**
 * \brief Initialize the NVIDIA Video Optical Flow interface and validate input params.
 *
 * Initializes the NVIDIA Video Optical Flow interface and validates input params.
 * It also initializes the NVIDIA Video Optical Flow driver with the init values passed in the
 * ::NV_OF_INIT_PARAMS structure.
*
* \param [in] hOf
* Object of ::NvOFHandle type.
* \param [in] initParams
* Pointer to the ::NV_OF_INIT_PARAMS structure.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
* ::NV_OF_ERR_UNSUPPORTED_DEVICE \n
* ::NV_OF_ERR_DEVICE_DOES_NOT_EXIST \n
* ::NV_OF_ERR_UNSUPPORTED_PARAM \n
* ::NV_OF_ERR_OUT_OF_MEMORY \n
* ::NV_OF_ERR_INVALID_PARAM \n
* ::NV_OF_ERR_INVALID_VERSION \n
* ::NV_OF_ERR_OF_NOT_INITIALIZED \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFINIT) (NvOFHandle hOf, const NV_OF_INIT_PARAMS *initParams);

/**
* \brief Kick off computation of optical flow between input and reference frame.
*
 * This is an asynchronous function call which kicks off the computation of optical flow or stereo disparity
 * between ::NV_OF_EXECUTE_INPUT_PARAMS::inputFrame and ::NV_OF_EXECUTE_INPUT_PARAMS::referenceFrame and returns
 * after submitting the execute parameters to the optical flow engine.
 * ::NV_OF_EXECUTE_OUTPUT_PARAMS::outputBuffer will be populated with optical flow or stereo disparity
 * depending on whether ::NV_OF_INIT_PARAMS::mode is ::NV_OF_MODE_OPTICALFLOW or ::NV_OF_MODE_STEREODISPARITY.
*
* \param [in] hOf
* Object of ::NvOFHandle type.
 * \param [in] executeInParams
 *   Pointer to the ::NV_OF_EXECUTE_INPUT_PARAMS structure.
 * \param [out] executeOutParams
 *   Pointer to the ::NV_OF_EXECUTE_OUTPUT_PARAMS structure.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
* ::NV_OF_ERR_INVALID_DEVICE \n
* ::NV_OF_ERR_DEVICE_DOES_NOT_EXIST \n
* ::NV_OF_ERR_UNSUPPORTED_PARAM \n
* ::NV_OF_ERR_OUT_OF_MEMORY \n
* ::NV_OF_ERR_INVALID_PARAM \n
* ::NV_OF_ERR_INVALID_VERSION \n
* ::NV_OF_ERR_OF_NOT_INITIALIZED \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFEXECUTE) (NvOFHandle hOf, const NV_OF_EXECUTE_INPUT_PARAMS *executeInParams, NV_OF_EXECUTE_OUTPUT_PARAMS *executeOutParams);
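
/*
 * Usage sketch (editor's example): dispatching through the PFNNVOFEXECUTE
 * pointer obtained from the API function list (see nvOpticalFlowCuda.h) and
 * checking the returned status code.
 */
static int nvofExecuteExample(PFNNVOFEXECUTE nvOFExecute,
                              NvOFHandle hOf,
                              const NV_OF_EXECUTE_INPUT_PARAMS* inParams,
                              NV_OF_EXECUTE_OUTPUT_PARAMS* outParams)
{
    NV_OF_STATUS status = nvOFExecute(hOf, inParams, outParams);
    return status == NV_OF_SUCCESS ? 0 : -1; /* on failure, query nvOFGetLastError */
}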

/**
* \brief Release optical flow API and driver resources.
*
* Releases resources and waits until all resources are gracefully released.
*
* \param [in] hOf
* Object of ::NvOFHandle type.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
* ::NV_OF_ERR_DEVICE_DOES_NOT_EXIST \n
* ::NV_OF_ERR_OF_NOT_INITIALIZED \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFDESTROY) (NvOFHandle hOf);

/**
 * \brief Populate the error buffer with the description of the last failure.
 *
 * Populates lastError[] with the description of the last failure.
 *
 * \param [in] hOf
 *   Object of ::NvOFHandle type.
 * \param [in/out] lastError
 *   lastError is a char array; the minimum expected size of lastError[] is MIN_ERROR_STRING_SIZE characters.
 *   After execution of this function call, lastError[] is populated with the error string.
 * \param [in/out] size
 *   As an input parameter, "size" indicates the size of the array provided by the client.
 *   After execution of this function call, "size" indicates the number of characters written into
 *   "lastError", excluding the null character.
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
* ::NV_OF_ERR_DEVICE_DOES_NOT_EXIST \n
* ::NV_OF_ERR_OF_NOT_INITIALIZED \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFGETLASTERROR) (NvOFHandle hOf, char lastError[], uint32_t *size);
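
/*
 * Usage sketch (editor's example): fetching the last error description through
 * the PFNNVOFGETLASTERROR pointer. MIN_ERROR_STRING_SIZE is the minimum buffer
 * size documented above; assumes <stdio.h> for fprintf.
 */
static void nvofPrintLastErrorExample(PFNNVOFGETLASTERROR nvOFGetLastError, NvOFHandle hOf)
{
    char lastError[MIN_ERROR_STRING_SIZE];
    uint32_t size = MIN_ERROR_STRING_SIZE;
    if (nvOFGetLastError(hOf, lastError, &size) == NV_OF_SUCCESS)
        fprintf(stderr, "NVOF error: %.*s\n", (int)size, lastError);
}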

/**
 * \brief Populate the capability array for the specified ::NV_OF_CAPS value.
 * This is to be called in two stages.
 * It returns the number of capability values for the specified ::NV_OF_CAPS value when
 * queried with "capsVal" set to NULL.
 * It populates the capsVal array with capability values for the specified ::NV_OF_CAPS value
 * when queried with "capsVal" set to a non-NULL value.
*
* \param [in] hOf
* Object of ::NvOFHandle type.
 * \param [in] capsParam
 *   Object of ::NV_OF_CAPS type.
 * \param [out] capsVal
 *   Pointer to uint32_t; the minimum expected size of capsVal is the "size" returned by this
 *   function call when queried with "capsVal" set to NULL.
 * \param [out] size
 *   Pointer to uint32_t, which stores the size of the populated capsVal.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
* ::NV_OF_ERR_DEVICE_DOES_NOT_EXIST \n
* ::NV_OF_ERR_OF_NOT_INITIALIZED \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFGETCAPS) (NvOFHandle hOf, NV_OF_CAPS capsParam, uint32_t *capsVal, uint32_t *size);
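
/*
 * Usage sketch (editor's example): the two-stage query pattern described above.
 * NV_OF_CAPS_SUPPORTED_OUTPUT_GRID_SIZES is an assumed ::NV_OF_CAPS enumerator;
 * substitute the capability you actually need. Assumes <stdlib.h> for malloc/free.
 */
static void nvofQueryCapsExample(PFNNVOFGETCAPS nvOFGetCaps, NvOFHandle hOf)
{
    uint32_t count = 0;
    /* stage 1: pass NULL to learn how many capability values there are */
    nvOFGetCaps(hOf, NV_OF_CAPS_SUPPORTED_OUTPUT_GRID_SIZES, NULL, &count);
    uint32_t* capsVal = (uint32_t*)malloc(count * sizeof(uint32_t));
    /* stage 2: pass the allocated array to receive the values */
    nvOFGetCaps(hOf, NV_OF_CAPS_SUPPORTED_OUTPUT_GRID_SIZES, capsVal, &count);
    /* ... inspect capsVal[0 .. count-1] ... */
    free(capsVal);
}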

/**
* \brief Get the largest API version supported by the driver.
*
* This function can be used by clients to determine if the driver supports
* the API header the application was compiled with.
*
* \param [out] version
 *   Pointer to the requested value. The 4 least significant bits of the returned
 *   value indicate the minor version, and the rest of the bits indicate the major
 *   version of the largest supported version.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
*/
NV_OF_STATUS NVOFAPI NvOFGetMaxSupportedApiVersion(uint32_t* version);
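
/*
 * Usage sketch (editor's example): decoding the packed version per the layout
 * documented above (low 4 bits = minor, remaining bits = major).
 */
static void nvofCheckApiVersionExample(void)
{
    uint32_t version = 0;
    if (NvOFGetMaxSupportedApiVersion(&version) == NV_OF_SUCCESS) {
        uint32_t major = version >> 4;
        uint32_t minor = version & 0xF;
        (void)major; (void)minor; /* e.g. compare against the header's API version */
    }
}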

#if defined(__cplusplus)
}
#endif /* __cplusplus */

#endif

#endif

+ 258
- 0
src/opr/impl/nvof/nvOpticalFlowCuda.h View File

@@ -0,0 +1,258 @@
/*
* This copyright notice applies to this header file only:
*
* Copyright (c) 2018 NVIDIA Corporation
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the software, and to permit persons to whom the
* software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
/**
 * \file nvOpticalFlowCuda.h
 * NVIDIA GPUs (Turing and above) contain a hardware-based optical flow engine
 * which provides fully accelerated optical flow and stereo estimation.
 * nvOpticalFlowCuda.h provides CUDA-specific enums, structure definitions and function pointer prototypes.
 * \date 2018
*/

/**
* \file src/opr/impl/nvof/nvOpticalFlowCuda.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"

#if MGB_CUDA
#ifndef _NV_OPTICALFLOW_CUDA_H_
#define _NV_OPTICALFLOW_CUDA_H_
#include "nvOpticalFlowCommon.h"
#include <cuda.h>
#define MAX_NUM_PLANES 3

#if defined(__cplusplus)

extern "C"
{
#endif /* __cplusplus */

/**
* Supported CUDA buffer types.
*/
typedef enum _NV_OF_CUDA_BUFFER_TYPE
{
NV_OF_CUDA_BUFFER_TYPE_UNDEFINED,
NV_OF_CUDA_BUFFER_TYPE_CUARRAY, /**< Buffer type is CUarray */
NV_OF_CUDA_BUFFER_TYPE_CUDEVICEPTR, /**< Buffer type is CUdeviceptr */
NV_OF_CUDA_BUFFER_TYPE_MAX
} NV_OF_CUDA_BUFFER_TYPE;

/**
* \struct NV_BUFFER_STRIDE
* Horizontal and vertical strides of a plane.
*/
typedef struct _NV_OF_BUFFER_STRIDE
{
uint32_t strideXInBytes; /**< Horizontal stride. */
uint32_t strideYInBytes; /**< Vertical stride. */
} NV_OF_BUFFER_STRIDE;

/**
* \struct NV_OF_CUDA_BUFFER_STRIDE_INFO
 * This structure stores buffer stride information, which is populated by the ::nvOFGPUBufferGetStrideInfo() API.
*/
typedef struct _NV_OF_CUDA_BUFFER_STRIDE_INFO
{
NV_OF_BUFFER_STRIDE strideInfo[MAX_NUM_PLANES]; /**< Stride information of each plane.*/
uint32_t numPlanes; /**< Number of planes. */
} NV_OF_CUDA_BUFFER_STRIDE_INFO;

/**
 * \brief Create an instance of an NvOFHandle object.
 *
 * This function creates an instance of an NvOFHandle object and returns the status.
 * The client is expected to release the NvOFHandle resource using the Destroy function call.
 *
 * \param [in] device
 *   Should be set to the CUDA context created by the client.
 * \param [out] hOf
 *   Pointer to an ::NvOFHandle object.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_OUT_OF_MEMORY \n
* ::NV_OF_ERR_INVALID_VERSION \n
* ::NV_OF_ERR_UNSUPPORTED_PARAM \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVCREATEOPTICALFLOWCUDA) (CUcontext device, NvOFHandle *hOf);

/**
* \brief Set input and output cuda stream for specified optical flow instance.
*
 * The optical flow algorithm may optionally involve CUDA preprocessing on the input buffers and
 * postprocessing on the output flow vectors. This function is used to set the input and output CUDA
 * streams that pipeline and synchronize the CUDA preprocessing and postprocessing tasks with the OF HW engine.
 * The client should call this function before the Execute function to update the input and/or output streams;
 * otherwise the Execute function will use the previously set streams, or the default streams if none were ever set.
*
* \param [in] hOf
* Object of ::NvOFHandle type.
 * \param [in] inputStream
 *   CUstream type object which is used to process ::NV_OF_EXECUTE_INPUT_PARAMS::inputFrame,
 *   ::NV_OF_EXECUTE_INPUT_PARAMS::referenceFrame and the optional ::NV_OF_EXECUTE_INPUT_PARAMS::externalHints.
 * \param [in] outputStream
 *   CUstream type object which is used to process ::NV_OF_EXECUTE_OUTPUT_PARAMS::outputBuffer and
 *   the optional ::NV_OF_EXECUTE_OUTPUT_PARAMS::outputCostBuffer.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
* ::NV_OF_ERR_INVALID_DEVICE \n
* ::NV_OF_ERR_DEVICE_DOES_NOT_EXIST \n
* ::NV_OF_ERR_UNSUPPORTED_PARAM \n
* ::NV_OF_ERR_OUT_OF_MEMORY \n
* ::NV_OF_ERR_INVALID_PARAM \n
* ::NV_OF_ERR_INVALID_VERSION \n
* ::NV_OF_ERR_OF_NOT_INITIALIZED \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFSETIOCUDASTREAMS) (NvOFHandle hOf, CUstream inputStream, CUstream outputStream);
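
/*
 * Usage sketch (editor's example): creating dedicated driver-API streams and
 * attaching them before the first execute call, via the PFNNVOFSETIOCUDASTREAMS
 * pointer obtained from the API function list.
 */
static void nvofSetStreamsExample(PFNNVOFSETIOCUDASTREAMS nvOFSetIOCudaStreams, NvOFHandle hOf)
{
    CUstream inputStream = NULL;
    CUstream outputStream = NULL;
    cuStreamCreate(&inputStream, CU_STREAM_DEFAULT);
    cuStreamCreate(&outputStream, CU_STREAM_DEFAULT);
    nvOFSetIOCudaStreams(hOf, inputStream, outputStream);
}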

/**
* \brief Create ::NvOFGPUBufferHandle resource.
*
 * This function creates an ::NvOFGPUBufferHandle resource for the specified CUDA bufferType.
 *
 * \param [in] hOf
 *   Object of ::NvOFHandle type.
 * \param [in] bufferDesc
 *   Pointer to the ::NV_OF_BUFFER_DESCRIPTOR structure.
 * \param [in] bufferType
 *   Buffer type, a value defined by enum ::NV_OF_CUDA_BUFFER_TYPE.
 * \param [out] hOfGpuBuffer
 *   Output pointer of ::NvOFGPUBufferHandle type.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
* ::NV_OF_ERR_DEVICE_DOES_NOT_EXIST \n
* ::NV_OF_ERR_OUT_OF_MEMORY \n
* ::NV_OF_ERR_INVALID_PARAM \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFCREATEGPUBUFFERCUDA) (NvOFHandle hOf, const NV_OF_BUFFER_DESCRIPTOR *bufferDesc,
NV_OF_CUDA_BUFFER_TYPE bufferType, NvOFGPUBufferHandle *hOfGpuBuffer);

/**
* \brief Return CUarray object associated with ::NvOFGPUBufferHandle type resource.
*
* \param [in] ofGpuBuffer
* Object of type NvOFGPUBufferHandle, created by a call to NvOFCreateGPUBufferCuda() with bufferType set to ::NV_OF_CUDA_BUFFER_TYPE_CUARRAY.
*
* \return
* Object of CUarray type.
 * If ofGpuBuffer corresponds to a GPU buffer that was not created with buffer type ::NV_OF_CUDA_BUFFER_TYPE_CUARRAY,
 * this function returns NULL.
*/
typedef CUarray(NVOFAPI* PFNNVOFGPUBUFFERGETCUARRAY) (NvOFGPUBufferHandle ofGpuBuffer);

/**
* \brief Return CUdeviceptr object associated with ::NvOFGPUBufferHandle type resource.
*
* \param [in] ofGpuBuffer
* Object of type NvOFGPUBufferHandle, created by a call to NvOFCreateGPUBufferCuda() with bufferType set to ::NV_OF_CUDA_BUFFER_TYPE_CUDEVICEPTR.
*
* \return
* Object of the CUdeviceptr type.
 * If ofGpuBuffer corresponds to a GPU buffer that was not created with buffer type ::NV_OF_CUDA_BUFFER_TYPE_CUDEVICEPTR,
 * this function returns 0.
*/
typedef CUdeviceptr(NVOFAPI* PFNNVOFGPUBUFFERGETCUDEVICEPTR) (NvOFGPUBufferHandle ofGpuBuffer);

/**
* \brief Populates buffer information associated with ::NvOFGPUBufferHandle type resource.
*
* Populates structure ::NV_OF_CUDA_BUFFER_STRIDE_INFO with the horizontal and vertical stride details of all the planes.
* \param [in] ofGpuBuffer
* Object of type NvOFGPUBufferHandle, created by a call to NvOFCreateGPUBufferCuda().
 * \param [out] strideInfo
 *   Pointer to the ::NV_OF_CUDA_BUFFER_STRIDE_INFO structure.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_PTR \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNVOFGPUBUFFERGETSTRIDEINFO) (NvOFGPUBufferHandle ofGpuBuffer, NV_OF_CUDA_BUFFER_STRIDE_INFO *strideInfo);
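
/*
 * Usage sketch (editor's example): querying the stride layout of a GPU buffer
 * before copying plane data, e.g. as the pitch arguments of a cuMemcpy2D.
 */
static void nvofIterateStridesExample(PFNVOFGPUBUFFERGETSTRIDEINFO nvOFGPUBufferGetStrideInfo,
                                      NvOFGPUBufferHandle ofGpuBuffer)
{
    NV_OF_CUDA_BUFFER_STRIDE_INFO strideInfo;
    if (nvOFGPUBufferGetStrideInfo(ofGpuBuffer, &strideInfo) == NV_OF_SUCCESS) {
        uint32_t i;
        for (i = 0; i < strideInfo.numPlanes; ++i) {
            /* strideXInBytes is the row pitch to use when addressing plane i */
            uint32_t pitch = strideInfo.strideInfo[i].strideXInBytes;
            (void)pitch;
        }
    }
}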

/**
 * \brief Destroy an NvOFGPUBufferHandle object and associated resources.
 *
 * \param [in] buffer
 *   Object of type ::NvOFGPUBufferHandle, created by a call to NvOFCreateGPUBufferCuda().
 *
 * \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_GENERIC \n
*/
typedef NV_OF_STATUS(NVOFAPI* PFNNVOFDESTROYGPUBUFFERCUDA) (NvOFGPUBufferHandle buffer);

/**
* \struct NV_OF_CUDA_API_FUNCTION_LIST
 * This is the structure of function pointers populated by the ::NvOFAPICreateInstanceCuda() API.
 * Each CUDA-specific function pointer is defined above.
*/
typedef struct _NV_OF_CUDA_API_FUNCTION_LIST
{
PFNNVCREATEOPTICALFLOWCUDA nvCreateOpticalFlowCuda;
PFNNVOFINIT nvOFInit;
PFNNVOFCREATEGPUBUFFERCUDA nvOFCreateGPUBufferCuda;
PFNNVOFGPUBUFFERGETCUARRAY nvOFGPUBufferGetCUarray;
PFNNVOFGPUBUFFERGETCUDEVICEPTR nvOFGPUBufferGetCUdeviceptr;
PFNVOFGPUBUFFERGETSTRIDEINFO nvOFGPUBufferGetStrideInfo;
PFNNVOFSETIOCUDASTREAMS nvOFSetIOCudaStreams;
PFNNVOFEXECUTE nvOFExecute;
PFNNVOFDESTROYGPUBUFFERCUDA nvOFDestroyGPUBufferCuda;
PFNNVOFDESTROY nvOFDestroy;
PFNNVOFGETLASTERROR nvOFGetLastError;
PFNNVOFGETCAPS nvOFGetCaps;
} NV_OF_CUDA_API_FUNCTION_LIST;

/**
* \brief ::NvOFAPICreateInstanceCuda() API is the entry point to the NvOFAPI interface.
*
* ::NvOFAPICreateInstanceCuda() API populates functionList with function pointers to the API routines implemented by the
* NvOFAPI interface.
*
* \return
* ::NV_OF_SUCCESS \n
* ::NV_OF_ERR_INVALID_VERSION \n
 * ::NV_OF_ERR_INVALID_PTR \n
*/
NV_OF_STATUS NVOFAPI NvOFAPICreateInstanceCuda(uint32_t apiVer, NV_OF_CUDA_API_FUNCTION_LIST *functionList);
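
/*
 * Usage sketch (editor's example): bootstrapping the CUDA interface.
 * NV_OF_API_VERSION is assumed to be the version macro from
 * nvOpticalFlowCommon.h, and cuCtx is a client-created CUDA context. In a
 * dynamic-loading setup, NvOFAPICreateInstanceCuda itself is usually resolved
 * from the NVOF shared library first.
 */
static NvOFHandle nvofBootstrapExample(CUcontext cuCtx, NV_OF_CUDA_API_FUNCTION_LIST* funcs)
{
    NvOFHandle hOf = NULL;
    if (NvOFAPICreateInstanceCuda(NV_OF_API_VERSION, funcs) == NV_OF_SUCCESS)
        funcs->nvCreateOpticalFlowCuda(cuCtx, &hOf);
    return hOf;
}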
#if defined(__cplusplus)
}
#endif /* __cplusplus */

#endif

#endif

+ 38
- 0
src/opr/include/megbrain/opr/misc.h View File

@@ -13,6 +13,10 @@

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#if MGB_CUDA
#include "../../../impl/nvof/denseflownvidia.h"
#include "megbrain/opr/param_defs.h"
#endif
#include "megdnn/oprs.h"

#include <array>
@@ -94,6 +98,40 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT<
void init_output_static_infer_desc() override;
};

#if MGB_CUDA
MGB_DEFINE_OPR_CLASS(NvOf, cg::SingleCNOperatorNodeBase) // {

public:
using Param = megdnn::param::NvOf;
NvOf(VarNode* src, const Param& param,
const OperatorNodeConfig& config);

// for serialization
static SymbolVar make(SymbolVar opr, const Param& param,
const OperatorNodeConfig& config = {});

static SymbolVar make(SymbolVar opr,
const OperatorNodeConfig& config = {}) {
return make(opr, {}, config);
}

Param param() const {
return m_param;
}

protected:
void init_output_dtype() override;
void scn_do_execute() override;
void init_output_static_infer_desc() override;

private:
std::shared_ptr<NVFlowExtractor> nv_flow_extractor;
std::vector<size_t> vshape;
Param m_param;
std::mutex m_lock;
bool init_flag = false;
};
#endif

namespace intl {
using CondTakeBase =


+ 1
- 2
src/serialization/impl/schema.fbs View File

@@ -28,7 +28,6 @@ table Blob {
}

table Reserved0 {}
table Reserved1 {}

union OperatorParam {
param.Empty = 1,
@@ -101,7 +100,7 @@ union OperatorParam {
param.Remap = 68,
param.NMSKeep = 69,
param.AdaptivePooling = 70,
Reserved1 = 71,
param.NvOf = 71,
}

table Operator {


+ 1
- 0
tools/param_defs/mgb_opr_param_defs.py View File

@@ -144,3 +144,4 @@ pdef('PersistentOutputStorage').add_fields(
)
)

(pdef('NvOf', 'opr implementing the NVIDIA Optical Flow SDK').add_fields('uint32', 'precision', 1))
