TensorRT – Extending the Model Input Dimensions of the TensorRT C++ API: Adding Dims5, Dims6, Dims7, Dims8
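1 Model input dimensions supported by the TensorRT C++ API
In TensorRT 7.0 and later, the input dimensions are usually specified with statements like the following: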
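// Assumed: m_TensorRT_Engine points to an nvinfer1::ICudaEngine, m_TensorRT_Context to an nvinfer1::IExecutionContext
const std::string input_name = "input";
const std::string output_name = "output";
// Look up binding indices by tensor name; getBindingIndex returns -1 if the name is not found
const int inputIndex = m_TensorRT_Engine->getBindingIndex(input_name.c_str());
const int outputIndex = m_TensorRT_Engine->getBindingIndex(output_name.c_str());
// Set the runtime shape of the input binding to (3, 100, 20)
m_TensorRT_Context->setBindingDimensions(inputIndex, nvinfer1::Dims3(3, 100, 20));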
Here Dims3 states that the model's input tensor is three-dimensional, with shape (3, 100, 20).
A typical deep learning model takes input of layout (C, H, W), that is, a three-dimensional tensor.
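Beyond that, the TensorRT C++ API provides ready-made dimension classes only up to Dims4, for models with 4-D input tensors. As deep learning frameworks grow more complex, however, more and more models need 5-D, 6-D, or even higher-dimensional tensors as network input. How can we extend the existing TensorRT API to describe such higher-dimensional input tensors?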
2 Extending the model input dimensions of the TensorRT C++ API
The file NvInferRuntimeCommon.h in the include directory of the TensorRT C++ API defines the class Dims32:
//!
//! \class Dims
//! \brief Structure to define the dimensions of a tensor.
//!
//! TensorRT can also return an invalid dims structure. This structure is represented by nbDims == -1
//! and d[i] == 0 for all d.
//!
//! TensorRT can also return an "unknown rank" dims structure. This structure is represented by nbDims == -1
//! and d[i] == -1 for all d.
//!
class Dims32
{
public:
//! The maximum rank (number of dimensions) supported for a tensor.
static constexpr int32_t MAX_DIMS{8};
//! The rank (number of dimensions).
int32_t nbDims;
//! The extent of each dimension.
int32_t d[MAX_DIMS];
};
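This class describes the dimensions of a tensor, and its definition shows that it supports at most 8 dimensions (MAX_DIMS = 8). The same header also declares using Dims = Dims32;, which is why the classes below derive from Dims.
NvInferLegacyDims.h, also in the include directory of the TensorRT C++ API, defines the dimension classes that TensorRT currently provides: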
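Because Dims itself stores up to MAX_DIMS = 8 extents, the Dims4 limit applies only to the ready-made convenience classes; a rank-5 shape can already be written by filling nbDims and d directly. The following is a minimal sketch of that idea, assuming a stand-in copy of the Dims32 definition quoted above so it compiles without the TensorRT headers. The shape (1, 3, 16, 112, 112) is a made-up example, and volume is a hypothetical helper, not TensorRT API.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in that mirrors the Dims32 definition quoted above, so this sketch
// compiles without TensorRT. In a real project, include NvInferRuntimeCommon.h
// and use nvinfer1::Dims instead.
class Dims32
{
public:
    static constexpr int32_t MAX_DIMS{8};
    int32_t nbDims;
    int32_t d[MAX_DIMS];
};
using Dims = Dims32;

// A 5-D shape written directly as an aggregate: nbDims = 5, and the unused
// entries d[5]..d[7] are zero-initialized.
const Dims kExample5D{5, {1, 3, 16, 112, 112}};

// Hypothetical helper: number of elements a shape describes (e.g. to size an
// input buffer). Assumes all extents are known, i.e. no -1 placeholders.
inline int64_t volume(const Dims& dims)
{
    assert(dims.nbDims >= 0 && dims.nbDims <= Dims::MAX_DIMS);
    int64_t count = 1;
    for (int32_t i = 0; i < dims.nbDims; ++i)
    {
        count *= dims.d[i];
    }
    return count;
}
```

The aggregate form works only because Dims has no user-declared constructors; the derived DimsN classes below add constructors purely for convenience.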
/*
* Copyright 1993-2021 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THESE LICENSED DELIVERABLES.
*
* U.S. Government End Users. These Licensed Deliverables are a
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
* 1995), consisting of "commercial computer software" and "commercial
* computer software documentation" as such terms are used in 48
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
* U.S. Government End Users acquire the Licensed Deliverables with
* only those rights set forth herein.
*
* Any use of the Licensed Deliverables in individual and commercial
* software must include, in the user documentation and internal
* comments to the code, the above Disclaimer and U.S. Government End
* Users Notice.
*/
#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H
#include "NvInferRuntimeCommon.h"
//!
//! \file NvInferLegacyDims.h
//!
//! This file contains declarations of legacy dimensions types which use channel
//! semantics in their names, and declarations on which those types rely.
//!
//!
//! \namespace nvinfer1
//!
//! \brief The TensorRT API version 1 namespace.
//!
namespace nvinfer1
{
//!
//! \class Dims2
//! \brief Descriptor for two-dimensional data.
//!
class Dims2 : public Dims
{
public:
//!
//! \brief Construct an empty Dims2 object.
//!
Dims2()
: Dims{2, {}}
{
}
//!
//! \brief Construct a Dims2 from 2 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//!
Dims2(int32_t d0, int32_t d1)
: Dims{2, {d0, d1}}
{
}
};
//!
//! \class DimsHW
//! \brief Descriptor for two-dimensional spatial data.
//!
class DimsHW : public Dims2
{
public:
//!
//! \brief Construct an empty DimsHW object.
//!
DimsHW()
: Dims2()
{
}
//!
//! \brief Construct a DimsHW given height and width.
//!
//! \param height the height of the data
//! \param width the width of the data
//!
DimsHW(int32_t height, int32_t width)
: Dims2(height, width)
{
}
//!
//! \brief Get the height.
//!
//! \return The height.
//!
int32_t& h()
{
return d[0];
}
//!
//! \brief Get the height.
//!
//! \return The height.
//!
int32_t h() const
{
return d[0];
}
//!
//! \brief Get the width.
//!
//! \return The width.
//!
int32_t& w()
{
return d[1];
}
//!
//! \brief Get the width.
//!
//! \return The width.
//!
int32_t w() const
{
return d[1];
}
};
//!
//! \class Dims3
//! \brief Descriptor for three-dimensional data.
//!
class Dims3 : public Dims
{
public:
//!
//! \brief Construct an empty Dims3 object.
//!
Dims3()
: Dims{3, {}}
{
}
//!
//! \brief Construct a Dims3 from 3 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//!
Dims3(int32_t d0, int32_t d1, int32_t d2)
: Dims{3, {d0, d1, d2}}
{
}
};
//!
//! \class Dims4
//! \brief Descriptor for four-dimensional data.
//!
class Dims4 : public Dims
{
public:
//!
//! \brief Construct an empty Dims4 object.
//!
Dims4()
: Dims{4, {}}
{
}
//!
//! \brief Construct a Dims4 from 4 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//! \param d3 The fourth element.
//!
Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
: Dims{4, {d0, d1, d2, d3}}
{
}
};
} // namespace nvinfer1
#endif // NV_INFER_LEGCY_DIMS_H
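Every class in this header follows one pattern: the constructor aggregate-initializes the base Dims with nbDims = N and the given extents, and the remaining entries of d are zero-filled. As a hedged sketch of that pattern (a swapped-in technique, not part of TensorRT and not the approach used in this article), a single variadic helper can build a Dims of any rank up to MAX_DIMS. makeDims and exampleUsage are hypothetical names, and the Dims32 stand-in again replaces the TensorRT header.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Stand-in mirroring the quoted Dims32 (assumption: lets the sketch compile
// without TensorRT; in practice use nvinfer1::Dims from NvInferRuntimeCommon.h).
class Dims32
{
public:
    static constexpr int32_t MAX_DIMS{8};
    int32_t nbDims;
    int32_t d[MAX_DIMS];
};
using Dims = Dims32;

// Hypothetical helper: builds a Dims whose rank equals the number of arguments.
// Ranks above MAX_DIMS are rejected at compile time, and unused entries of d[]
// are zero-initialized by aggregate initialization.
template <typename... Extents>
Dims makeDims(Extents... extents)
{
    static_assert(sizeof...(Extents) <= static_cast<std::size_t>(Dims::MAX_DIMS),
                  "rank exceeds Dims::MAX_DIMS");
    return Dims{static_cast<int32_t>(sizeof...(Extents)), {static_cast<int32_t>(extents)...}};
}

// Usage: a 5-D input shape, e.g. for the setBindingDimensions call shown earlier.
inline void exampleUsage()
{
    const Dims input = makeDims(1, 3, 16, 112, 112);
    assert(input.nbDims == 5 && input.d[4] == 112 && input.d[5] == 0);
}
```

Calling makeDims with nine arguments fails to compile, which mirrors the hard MAX_DIMS = 8 limit of Dims.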
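As NvInferLegacyDims.h shows, defining a dimension type only requires deriving from Dims and initializing nbDims and d accordingly. So, to let TensorRT accept higher-rank inputs such as Dims5, Dims6, Dims7, and Dims8, we define these classes ourselves. The extended NvInferLegacyDims.h reads as follows: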
/*
* (NVIDIA license header, identical to the one reproduced in the original file above.)
*/
#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H
#include "NvInferRuntimeCommon.h"
//!
//! \file NvInferLegacyDims.h
//!
//! This file contains declarations of legacy dimensions types which use channel
//! semantics in their names, and declarations on which those types rely.
//!
//!
//! \namespace nvinfer1
//!
//! \brief The TensorRT API version 1 namespace.
//!
namespace nvinfer1
{
//!
//! \class Dims2
//! \brief Descriptor for two-dimensional data.
//!
class Dims2 : public Dims
{
public:
//!
//! \brief Construct an empty Dims2 object.
//!
Dims2()
: Dims{2, {}}
{
}
//!
//! \brief Construct a Dims2 from 2 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//!
Dims2(int32_t d0, int32_t d1)
: Dims{2, {d0, d1}}
{
}
};
//!
//! \class DimsHW
//! \brief Descriptor for two-dimensional spatial data.
//!
class DimsHW : public Dims2
{
public:
//!
//! \brief Construct an empty DimsHW object.
//!
DimsHW()
: Dims2()
{
}
//!
//! \brief Construct a DimsHW given height and width.
//!
//! \param height the height of the data
//! \param width the width of the data
//!
DimsHW(int32_t height, int32_t width)
: Dims2(height, width)
{
}
//!
//! \brief Get the height.
//!
//! \return The height.
//!
int32_t& h()
{
return d[0];
}
//!
//! \brief Get the height.
//!
//! \return The height.
//!
int32_t h() const
{
return d[0];
}
//!
//! \brief Get the width.
//!
//! \return The width.
//!
int32_t& w()
{
return d[1];
}
//!
//! \brief Get the width.
//!
//! \return The width.
//!
int32_t w() const
{
return d[1];
}
};
//!
//! \class Dims3
//! \brief Descriptor for three-dimensional data.
//!
class Dims3 : public Dims
{
public:
//!
//! \brief Construct an empty Dims3 object.
//!
Dims3()
: Dims{3, {}}
{
}
//!
//! \brief Construct a Dims3 from 3 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//!
Dims3(int32_t d0, int32_t d1, int32_t d2)
: Dims{3, {d0, d1, d2}}
{
}
};
//!
//! \class Dims4
//! \brief Descriptor for four-dimensional data.
//!
class Dims4 : public Dims
{
public:
//!
//! \brief Construct an empty Dims4 object.
//!
Dims4()
: Dims{4, {}}
{
}
//!
//! \brief Construct a Dims4 from 4 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//! \param d3 The fourth element.
//!
Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
: Dims{4, {d0, d1, d2, d3}}
{
}
};
//!
//! \class Dims5
//! \brief Descriptor for five-dimensional data.
//!
class Dims5 : public Dims
{
public:
//!
//! \brief Construct an empty Dims5 object.
//!
Dims5()
: Dims{5, {}}
{
}
//!
//! \brief Construct a Dims5 from 5 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//! \param d3 The fourth element.
//! \param d4 The fifth element.
//!
Dims5(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4)
: Dims{5, {d0, d1, d2, d3, d4}}
{
}
};
//!
//! \class Dims6
//! \brief Descriptor for six-dimensional data.
//!
class Dims6 : public Dims
{
public:
//!
//! \brief Construct an empty Dims6 object.
//!
Dims6()
: Dims{6, {}}
{
}
//!
//! \brief Construct a Dims6 from 6 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//! \param d3 The fourth element.
//! \param d4 The fifth element.
//! \param d5 The sixth element.
//!
Dims6(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5)
: Dims{6, {d0, d1, d2, d3, d4, d5}}
{
}
};
//!
//! \class Dims7
//! \brief Descriptor for seven-dimensional data.
//!
class Dims7 : public Dims
{
public:
//!
//! \brief Construct an empty Dims7 object.
//!
Dims7()
: Dims{7, {}}
{
}
//!
//! \brief Construct a Dims7 from 7 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//! \param d3 The fourth element.
//! \param d4 The fifth element.
//! \param d5 The sixth element.
//! \param d6 The seventh element.
//!
Dims7(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6)
: Dims{7, {d0, d1, d2, d3, d4, d5, d6}}
{
}
};
//!
//! \class Dims8
//! \brief Descriptor for eight-dimensional data.
//!
class Dims8 : public Dims
{
public:
//!
//! \brief Construct an empty Dims8 object.
//!
Dims8()
: Dims{8, {}}
{
}
//!
//! \brief Construct a Dims8 from 8 elements.
//!
//! \param d0 The first element.
//! \param d1 The second element.
//! \param d2 The third element.
//! \param d3 The fourth element.
//! \param d4 The fifth element.
//! \param d5 The sixth element.
//! \param d6 The seventh element.
//! \param d7 The eighth element.
//!
Dims8(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6, int32_t d7)
: Dims{8, {d0, d1, d2, d3, d4, d5, d6, d7}}
{
}
};
} // namespace nvinfer1
#endif // NV_INFER_LEGACY_DIMS_H
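After modifying NvInferLegacyDims.h, recompile your project, and the new Dims5, Dims6, Dims7, and Dims8 classes can be used to specify 5-, 6-, 7-, and 8-dimensional network inputs. Because these classes are defined entirely in the header, only your own code needs to be recompiled; the TensorRT libraries themselves are unchanged. Ranks above 8 remain impossible, since Dims stores at most MAX_DIMS = 8 extents.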
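One drawback of editing a header that ships with TensorRT is that the change is lost whenever TensorRT is reinstalled or upgraded. As an alternative sketch (not the approach described above), the same classes can live in a header of your own, in your own namespace; my_trt_ext and exampleShape are hypothetical names, and a stand-in for Dims again replaces the TensorRT header so the snippet compiles on its own.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in mirroring the quoted Dims32 (assumption: in a real project, include
// NvInferRuntimeCommon.h instead and derive from nvinfer1::Dims).
class Dims32
{
public:
    static constexpr int32_t MAX_DIMS{8};
    int32_t nbDims;
    int32_t d[MAX_DIMS];
};
using Dims = Dims32;

// Hypothetical project-local namespace, so a future TensorRT release that adds
// its own Dims5 cannot collide with this one.
namespace my_trt_ext
{
class Dims5 : public Dims
{
public:
    // Empty 5-D shape: all extents zero.
    Dims5()
        : Dims{5, {}}
    {
    }
    // 5-D shape from explicit extents; d[5]..d[7] are zero-initialized.
    Dims5(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4)
        : Dims{5, {d0, d1, d2, d3, d4}}
    {
    }
};
// Dims6, Dims7 and Dims8 follow the same pattern.
} // namespace my_trt_ext

// Usage sketch. Dims5 converts to a plain Dims by slicing, which loses nothing
// because Dims5 adds no data members; in a real program the object would be
// passed on, e.g. m_TensorRT_Context->setBindingDimensions(inputIndex, shape).
inline Dims exampleShape()
{
    const my_trt_ext::Dims5 shape(1, 3, 16, 112, 112);
    const Dims plain = shape;
    assert(plain.nbDims == 5);
    return plain;
}
```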
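Author: StubbornHuang
Copyright notice: This is an original article by the site owner. If you republish it, please credit the original link!
Original title: TensorRT – Extending the Model Input Dimensions of the TensorRT C++ API: Adding Dims5, Dims6, Dims7, Dims8
Original link: https://www.stubbornhuang.com/1761/
Published: 2021-10-19 14:03:32
Modified: 2023-06-26 21:10:20
Notice: Unless otherwise stated, all articles on this site are original works first published here. No individual or organization may copy, misappropriate, scrape, or republish this site's content on any website, in books, or on any other media without the site's consent. If any content on this site infringes the legitimate rights of an original author, please contact us and we will deal with it.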