本文首发于 MXNet与Gluon的黑科技工具箱写文章
作者:梁德澎
前言
本文主要内容是解读MXNet是加载模型参数文件并解析得到NDArray所涉及到的代码,希望读者读完本文能对MXNet参数文件的格式有清晰的了解,并可以自己来实现参数文件的解析。
解析MXNet参数文件C++小工程:
https://github.com/Ldpe2G/DeepLearningForFun/tree/master/MXNet-Cpp/parsingNDArraygithub.com
本文解读的MXNet代码基于的版本:commit 7d2c9bf3b631433132452760734b684e39170814
Python前端代码入口
首先从MXNet Python前端看是如何是调用底层C接口来读取NDArray参数文件的,这部分代码见源码 ${MXNET_ROOT}/python/mxnet/ndarray/utils.py
第149行:
def load(fname):
if not isinstance(fname, string_types):
raise TypeError('fname required to be a string')
out_size = mx_uint()
out_name_size = mx_uint()
handles = ctypes.POINTER(NDArrayHandle)()
names = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.MXNDArrayLoad(c_str(fname),
ctypes.byref(out_size),
ctypes.byref(handles),
ctypes.byref(out_name_size),
ctypes.byref(names)))
.....
这个 load 函数接收参数路径作为输入,主要是调用 MXNDArrayLoad
这个底层的C函数接口来读取参数。MXNet底层实现是C++,并提供了一层C的接口供前端语言去调用。
C接口层
接着来看下MXNDArrayLoad
接口的实现,这部分代码见${MXNET_ROOT}/src/c_api/c_api.cc
第1344行:
int MXNDArrayLoad(const char* fname,
uint32_t *out_size,
NDArrayHandle** out_arr,
uint32_t *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
ret->ret_vec_str.clear();
API_BEGIN();
std::vector<NDArray> data;
std::vector<std::string> &names = ret->ret_vec_str;
{
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
mxnet::NDArray::Load(fi.get(), &data, &names);
}
......
}
核心代码就是首先打开文件流,接着调用NDArray类的静态函数mxnet::NDArray::Load
函数读取并解析参数文件,得到参数NDArray
数组保存到data
这个变量里面。我们只要关注dmlc::Stream
这个类的实现还有mxnet::NDArray::Load
这个类的实现就可以了。
底层C++实现
NDArray::Load静态函数
函数具体实现见${MXNET_ROOT}/src/ndarray/ndarray.cc
第 1924 行:
void NDArray::Load(dmlc::Stream* fi,
std::vector<NDArray>* data,
std::vector<std::string>* keys) {
uint64_t header, reserved;
CHECK(fi->Read(&header))
<< "Invalid NDArray file format";
CHECK(fi->Read(&reserved))
<< "Invalid NDArray file format";
CHECK(header == kMXAPINDArrayListMagic)
<< "Invalid NDArray file format";
CHECK(fi->Read(data))
<< "Invalid NDArray file format";
CHECK(fi->Read(keys))
<< "Invalid NDArray file format";
CHECK(keys->size() == 0 || keys->size() == data->size())
<< "Invalid NDArray file format";
}
从这里读取内容的过程可以大概看出NDArray参数文件存储的内容格式。
首先文件开头保存了两个uint64\_t类型的数字,接着就是NDArray参数数组,接着是每个NDArray对应的名字数组。
读取的时候都是调用Strem类的Read
函数,接下来就是看下Stream类的实现。
Stream类
看回上面打开参数文件的代码:
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
dmlc::Stream::Create
代码见dmlc-core
子模块:${MXNET_ROOT}/3rdparty/dmlc-core/src/io.cc
第132行:
Stream *Stream::Create(const char *uri,
const char * const flag,
bool try_create) {
io::URI path(uri);
return io::FileSystem::
GetInstance(path)->Open(path, flag, try_create);
}
调用了FileSystem::GetINstance
函数得到实例,并调用Open
函数打开文件,这里返回的实例是LocalFileSystem
类的实例,其Open
函数见${MXNET_ROOT}/3rdparty/dmlc-core/src/io/local_filesys.cc
第147行:
SeekStream *LocalFileSystem::Open(const URI &path,
const char* const mode,
bool allow_null) {
FILE *fp = NULL;
const char *fname = path.name.c_str();
using namespace std;
std::string flag = mode;
if (flag == "r") flag = "rb";
fp = fopen(fname, flag.c_str());
if (fp != NULL) {
return new FileStream(fp, false);
} else {
return NULL;
}
}
为了可读性我简化了代码,可以看到就是调用std::fopen
函数打开文件,并把FILE
指针传给FileStream
类,
代码见${MXNET_ROOT}/3rdparty/dmlc-core/src/io/local_filesys.cc
第27行:
class FileStream : public SeekStream {
public:
explicit FileStream(FILE *fp, bool use_stdio)
: fp_(fp), use_stdio_(use_stdio) {}
virtual ~FileStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp_);
}
......
private:
std::FILE *fp_;
bool use_stdio_;
};
可以看到FileStream
继承自SeekStrem
,而且成员函数Read
实现的功能是调用std::fread
函数从fp_
文件指针里面读取size大小字节的内容,std::fread
的文档见https://en.cppreference.com/w/cpp/io/c/fread
:
看下每个参数的解释:
buffer - void 指针指向从文件留中读取内容的存取内存
size - 指针指向内存每个元素字节大小,这里由于是void指针,所以size大小恒为1
count - 读取元素个数
stream - 文件流
MXNet这里的实现是把需要被读取的内存指针转换成void *
,这样子就可以兼容各种基本类型的指针读取,只需要记住传入的读取元素个数是 sizeof(T) * count
,就是原来类型元素个数乘以每个元素对应的字节数。
接着再看回上面打开参数文件的代码:
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
Strem::Create
返回的是Strem
类型,而不是SeekStrem
,所以继续往上找Strem
类的定义,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/io.h
第30行:
class Stream { // NOLINT(*)
public:
virtual size_t Read(void *ptr, size_t size) = 0;
static Stream *Create(const char *uri,
const char* const flag,
bool allow_null = false);
template<typename T>
inline bool Read(T *out_data);
template<typename T>
inline bool ReadArray(T* data, size_t num_elems);
};
为了可读性,只保留了读文件相关的代码,可以看到,FileStream
重写了virtual size_t Read(void *ptr, size_t size) = 0;
虚函数,
而回看NDArray
静态Load
函数:
void NDArray::Load(dmlc::Stream* fi,
std::vector<NDArray>* data,
std::vector<std::string>* keys) {
uint64_t header, reserved;
CHECK(fi->Read(&header))
<< "Invalid NDArray file format";
CHECK(fi->Read(&reserved))
<< "Invalid NDArray file format";
CHECK(header == kMXAPINDArrayListMagic)
<< "Invalid NDArray file format";
CHECK(fi->Read(data))
<< "Invalid NDArray file format";
CHECK(fi->Read(keys))
<< "Invalid NDArray file format";
CHECK(keys->size() == 0 || keys->size() == data->size())
<< "Invalid NDArray file format";
}
具体读参数文件内容的时候调用的是Stream
类的Read(T *out_data)
模板函数
template<typename T>
inline bool Read(T *out_data);
这个模板函数的实现很有意思,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/io.h
第455:
template<typename T>
inline bool Stream::Read(T *out_data) {
return serializer::Handler<T>::Read(this, out_data);
}
可以看到调用了Handler::Read
函数,继续跟进去看Handler
的实现,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h
第258行:
template<typename T>
struct Handler {
......
/*!
* \brief read data to stream
* \param strm the stream to read the data.
* \param data the pointer to the data obeject to read
* \return whether the read is successful
*/
inline static bool Read(Stream *strm, T *data) {
return
IfThenElse<dmlc::is_arithmetic<T>::value,
ArithmeticHandler<T>,
IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
NativePODHandler<T>,
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>,
T>
::Read(strm, data);
}
};
一开始看到这串代码可能会有点懵,不过没关系,接下来我们就一步步拆解这段代码,首先看IfThenElse
结构体的实现,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h
第48行:
template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;
template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
......
inline static bool Read(Stream *strm, T *data) {
return Then::Read(strm, data);
}
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
......
inline static bool Read(Stream *strm, T *data) {
return Else::Read(strm, data);
}
};
我理解就是根据模板参数在编译期间做分支选择,根据模板参数决定调用实现分支,这里可以看到如果第一个模板参数template<bool cond, ...>
为true
的话,就调用 Then::Read
函数 ,否则调用Else::Read
函数,然后看回Handler::Read
函数:
inline static bool Read(Stream *strm, T *data) {
return
IfThenElse<dmlc::is_arithmetic<T>::value,
ArithmeticHandler<T>,
IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
NativePODHandler<T>,
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>,
T>
::Read(strm, data);
}
};
代码就很好理解了,如果dmlc::is_arithmetic<T>::value
值为true
则走ArithmeticHandler<T>
,否则再进行第二次判断,先来看下具体实现${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h
第66行:
template<typename T>
struct is_arithmetic {
static const bool value = std::is_arithmetic<T>::value;
};
为了可读性我简化了下代码,只要查下C++的文档看下std::is_arithmetic<T>
的定义就知道什么模板参数类型是什么的情况下值是true或者false,C++文档解释见https://en.cppreference.com/w/cpp/types/is_arithmetic
:
也就是如果模板类型T
是整数型或者浮点型,value
的值就是true
否则是false
。如果满足条件则会调用ArithmeticHandler::Read
函数,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h
第82行:
/*! \brief Serializer for arithmetic data, handle endianness */
template<typename T>
struct ArithmeticHandler {
......
inline static bool Read(Stream *strm, T *dptr) {
bool ret = strm->Read((void*)dptr, sizeof(T)) == sizeof(T);
......
return ret;
}
};
就是运行时调用子类重写的Read
函数,从文件流中读取一个T
类型元素。接着再看回其他分支选择:
inline static bool Read(Stream *strm, T *data) {
return
IfThenElse<dmlc::is_arithmetic<T>::value,
ArithmeticHandler<T>,
IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
NativePODHandler<T>,
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>,
T>
::Read(strm, data);
}
};
如果不满足ArithmeticHandler
的条件,则看下面一个判断dmlc::is_pod<T>
,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h
第20行:
template<typename T>
struct is_pod {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_pod<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
看C++文档解释https://en.cppreference.com/w/cpp/types/is_pod
:
如果是plain old data type
值就是true
,关于PODType
的解释大家可以参考:
https://zhuanlan.zhihu.com/p/29734547
https://en.cppreference.com/w/cpp/named\_req/PODType
而NativePODHandler::Read
函数的实现也是和ArithmeticHandler::Read
类似,也是运行时调用Stream
子类重写的Read
函数,从文件流中读取一个T
类型元素
template<typename T>
struct NativePODHandler {
......
inline static bool Read(Stream *strm, T *dptr) {
return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*)
}
};
接着继续回看下一个条件判断dmlc::has_saveload<T>
:
inline static bool Read(Stream *strm, T *data) {
return
IfThenElse<......
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>,
T>
::Read(strm, data);
}
};
代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h
第109行:
template<typename T>
struct has_saveload {
/*! \brief the value of the traits */
static const bool value = false;
};
默认值是false
,不过看到${MXNET_ROOT}/include/mxnet/ndarray.h
第1492行:
namespace dmlc {
/*!\brief traits */
DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true);
} // namespace dmlc
和${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h
第125行,DMLC_DECLARE_TRAITS
的宏定义:
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \
template<> \
struct Trait<Type> { \
static const bool value = Value; \
}
就知道对于NDArray
类来说,dmlc::has_saveload<NDArray>::value == true
,所以可以判断如果模板参数类型是NDArray
则会进入SaveLoadClassHandler
实现:
template<typename T>
struct SaveLoadClassHandler {
......
inline static bool Read(Stream *strm, T *data) {
return data->Load(strm);
}
};
实际就是调用了T::Load
函数,也就是NDArray::Load
成员函数。
Handler类还提供了其他模板参数类型的支持比如vector<T>
或者std::string
:
template<typename T>
struct Handler<std::vector<T> > {
......
inline static bool Read(Stream *strm, std::vector<T> *data) {
return IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
NativePODVectorHandler<T>,
ComposeVectorHandler<T>,
std::vector<T> >
::Read(strm, data);
}
};
template<typename T>
struct Handler<std::basic_string<T> > {
.....
inline static bool Read(Stream *strm, std::basic_string<T> *data) {
return IfThenElse<dmlc::is_pod<T>::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1),
NativePODStringHandler<T>,
UndefinedSerializerFor<T>,
std::basic_string<T> >
::Read(strm, data);
}
};
相信有了前面Handler
类的解释,理解这两个模板类的实现也就容易多了,这里就不展开了有兴趣的读者可以继续去深入了解,下面回来继续看参数NDArray
加载参数逻辑部分的代码解读。
MXNet参数文件解析逻辑
首先给出MXNet参数文件存储内容的格式示意图:
然后根据官方代码的解析逻辑,我自己实现的参数提取代码,为了可读性简化了代码,完整代码见文章开头的github链接:
struct cpu {
static const int kDevMask = 1 << 0;
};
struct gpu {
static const int kDevMask = 1 << 1;
};
enum DeviceType {
kCPU = cpu::kDevMask,
kGPU = gpu::kDevMask,
kCPUPinned = 3,
kCPUShared = 5,
};
static bool Read(std::FILE *fp, void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp) == size;
}
int32_t loadNDArrayV2(std::vector<NDArray *>& ndarrays, std::string param_file) {
std::FILE *fp = fopen(param_file.c_str(), "rb");
uint64_t header, reserved;
Read(fp, (void*)(&header), sizeof(uint64_t));
Read(fp, (void*)(&reserved), sizeof(uint64_t))
uint64_t nd_size;
Read(fp, (void*)(&nd_size), sizeof(uint64_t));
size_t size = static_cast<size_t>(nd_size);
ndarrays.resize(size);
// read nd data
for (size_t i = 0; i < nd_size; ++i) {
NDArray* nd = new NDArray;
ndarrays[i] = nd;
uint32_t magic;
Read(fp, (void*)(&magic), sizeof(uint32_t));
// load storage type
int32_t stype;
Read(fp, (void*)(&stype), sizeof(int32_t));
// load shape
uint32_t ndim_{0};
Read(fp, (void*)(&ndim_), sizeof(uint32_t));
size_t nread = sizeof(int64_t) * ndim_;
int64_t *data_heap_ = new int64_t[ndim_];
Read(fp, (void*)data_heap_, nread);
int64_t size = 1;
for (uint32_t i=0; i<ndim_;++i) {
size *= data_heap_[i];
nd->shape.push_back(data_heap_[i]);
}
delete[] data_heap_;
// load context
DeviceType dev_type;
int32_t dev_id;
Read(fp, (void*)(&dev_type), sizeof(dev_type));
Read(fp, (void*)(&dev_id), sizeof(int32_t));
// load type flag
int32_t type_flag;
Read(fp, (void*)(&type_flag), sizeof(int32_t));
size_t all_size = size * mshadow_sizeof(type_flag);
nd->numOfBytes = all_size;
nd->data = (void *)malloc(all_size);
Read(fp, nd->data, nd->numOfBytes);
}
// read nd names
std::vector<std::string> keys;
uint64_t keysLen;
Read(fp, (void*)(&keysLen), sizeof(uint64_t));
keys.resize(keysLen);
for (uint64_t k = 0; k < keysLen; ++k) {
uint64_t stringLen;
Read(fp, (void*)(&stringLen), sizeof(uint64_t));
size_t size = static_cast<size_t>(stringLen);
keys[k].resize(size);
if (size != 0) {
size_t nbytes = sizeof(char) * size;
Read(fp, (void*)(&(keys[k][0])), nbytes);
}
}
std::fclose(fp);
return kSuccess;
}
结合示意图和代码,应该就能比较好的理解参数文件的存储格式和解析方法了。
推荐文章
更多AI移动端优化的请关注专栏嵌入式AI以及知乎(@梁德澎)。