梁德澎 · 2020年05月13日

MXNet源码解读笔记一 ---- 如何解析参数文件

本文首发于 MXNet与Gluon的黑科技工具箱写文章
作者:梁德澎

前言

本文主要内容是解读MXNet是加载模型参数文件并解析得到NDArray所涉及到的代码,希望读者读完本文能对MXNet参数文件的格式有清晰的了解,并可以自己来实现参数文件的解析。

解析MXNet参数文件C++小工程:

https://github.com/Ldpe2G/DeepLearningForFun/tree/master/MXNet-Cpp/parsingNDArray​github.com

本文解读的MXNet代码基于的版本:commit 7d2c9bf3b631433132452760734b684e39170814

Python前端代码入口

首先从MXNet Python前端看是如何是调用底层C接口来读取NDArray参数文件的,这部分代码见源码 ${MXNET_ROOT}/python/mxnet/ndarray/utils.py 第149行:

def load(fname):
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),
                                  ctypes.byref(out_size),
                                  ctypes.byref(handles),
                                  ctypes.byref(out_name_size),
                                  ctypes.byref(names)))
    .....

这个 load 函数接收参数路径作为输入,主要是调用 MXNDArrayLoad 这个底层的C函数接口来读取参数。MXNet底层实现是C++,并提供了一层C的接口供前端语言去调用。

C接口层

接着来看下MXNDArrayLoad接口的实现,这部分代码见${MXNET_ROOT}/src/c_api/c_api.cc第1344行:

int MXNDArrayLoad(const char* fname,
                  uint32_t *out_size,
                  NDArrayHandle** out_arr,
                  uint32_t *out_name_size,
                  const char*** out_names) {
  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
  ret->ret_vec_str.clear();
  API_BEGIN();
  std::vector<NDArray> data;
  std::vector<std::string> &names = ret->ret_vec_str;
  {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
    mxnet::NDArray::Load(fi.get(), &data, &names);
  }
  ......
}

核心代码就是首先打开文件流,接着调用NDArray类的静态函数mxnet::NDArray::Load函数读取并解析参数文件,得到参数NDArray数组保存到data这个变量里面。我们只要关注dmlc::Stream这个类的实现还有mxnet::NDArray::Load这个类的实现就可以了。

底层C++实现

NDArray::Load静态函数

函数具体实现见${MXNET_ROOT}/src/ndarray/ndarray.cc第 1924 行:

void NDArray::Load(dmlc::Stream* fi,
                   std::vector<NDArray>* data,
                   std::vector<std::string>* keys) {
  uint64_t header, reserved;
  CHECK(fi->Read(&header))
      << "Invalid NDArray file format";
  CHECK(fi->Read(&reserved))
      << "Invalid NDArray file format";
  CHECK(header == kMXAPINDArrayListMagic)
      << "Invalid NDArray file format";
  CHECK(fi->Read(data))
      << "Invalid NDArray file format";
  CHECK(fi->Read(keys))
      << "Invalid NDArray file format";
  CHECK(keys->size() == 0 || keys->size() == data->size())
      << "Invalid NDArray file format";
}

从这里读取内容的过程可以大概看出NDArray参数文件存储的内容格式。

首先文件开头保存了两个uint64\_t类型的数字,接着就是NDArray参数数组,接着是每个NDArray对应的名字数组。

读取的时候都是调用Strem类的Read函数,接下来就是看下Stream类的实现。

Stream类

看回上面打开参数文件的代码:

std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));

dmlc::Stream::Create代码见dmlc-core子模块:${MXNET_ROOT}/3rdparty/dmlc-core/src/io.cc第132行:

Stream *Stream::Create(const char *uri,
                       const char * const flag,
                       bool try_create) {
  io::URI path(uri);
  return io::FileSystem::
      GetInstance(path)->Open(path, flag, try_create);
}

调用了FileSystem::GetINstance函数得到实例,并调用Open函数打开文件,这里返回的实例是LocalFileSystem类的实例,其Open函数见${MXNET_ROOT}/3rdparty/dmlc-core/src/io/local_filesys.cc第147行:

SeekStream *LocalFileSystem::Open(const URI &path,
                                  const char* const mode,
                                  bool allow_null) {
  FILE *fp = NULL;
  const char *fname = path.name.c_str();
  using namespace std;
  std::string flag = mode;
  if (flag == "r") flag = "rb";
  fp = fopen(fname, flag.c_str());
  if (fp != NULL) {
    return new FileStream(fp, false);
  } else {
    return NULL;
  }
}

为了可读性我简化了代码,可以看到就是调用std::fopen函数打开文件,并把FILE指针传给FileStream类,

代码见${MXNET_ROOT}/3rdparty/dmlc-core/src/io/local_filesys.cc第27行:

class FileStream : public SeekStream {
 public:
  explicit FileStream(FILE *fp, bool use_stdio)
      : fp_(fp), use_stdio_(use_stdio) {}
  virtual ~FileStream(void) {
    this->Close();
  }
  virtual size_t Read(void *ptr, size_t size) {
    return std::fread(ptr, 1, size, fp_);
  }

  ......
  
 private:
  std::FILE *fp_;
  bool use_stdio_;
};

可以看到FileStream继承自SeekStrem,而且成员函数Read实现的功能是调用std::fread函数从fp_文件指针里面读取size大小字节的内容,std::fread的文档见https://en.cppreference.com/w/cpp/io/c/fread

看下每个参数的解释:

buffer - void 指针指向从文件留中读取内容的存取内存
  size   - 指针指向内存每个元素字节大小,这里由于是void指针,所以size大小恒为1
  count  - 读取元素个数
  stream - 文件流

MXNet这里的实现是把需要被读取的内存指针转换成void *,这样子就可以兼容各种基本类型的指针读取,只需要记住传入的读取元素个数是 sizeof(T) * count,就是原来类型元素个数乘以每个元素对应的字节数。

接着再看回上面打开参数文件的代码:

std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));

Strem::Create返回的是Strem类型,而不是SeekStrem,所以继续往上找Strem类的定义,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/io.h第30行:

class Stream {  // NOLINT(*)
 public:

  virtual size_t Read(void *ptr, size_t size) = 0;

  static Stream *Create(const char *uri,
                        const char* const flag,
                        bool allow_null = false);
  
  template<typename T>
  inline bool Read(T *out_data);

  template<typename T>
  inline bool ReadArray(T* data, size_t num_elems);
};

为了可读性,只保留了读文件相关的代码,可以看到,FileStream重写了virtual size_t Read(void *ptr, size_t size) = 0;虚函数,

而回看NDArray静态Load函数:

void NDArray::Load(dmlc::Stream* fi,
                   std::vector<NDArray>* data,
                   std::vector<std::string>* keys) {
  uint64_t header, reserved;
  CHECK(fi->Read(&header))
      << "Invalid NDArray file format";
  CHECK(fi->Read(&reserved))
      << "Invalid NDArray file format";
  CHECK(header == kMXAPINDArrayListMagic)
      << "Invalid NDArray file format";
  CHECK(fi->Read(data))
      << "Invalid NDArray file format";
  CHECK(fi->Read(keys))
      << "Invalid NDArray file format";
  CHECK(keys->size() == 0 || keys->size() == data->size())
      << "Invalid NDArray file format";
}

具体读参数文件内容的时候调用的是Stream类的Read(T *out_data)模板函数

  template<typename T>
  inline bool Read(T *out_data);

这个模板函数的实现很有意思,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/io.h第455:

template<typename T>
inline bool Stream::Read(T *out_data) {
  return serializer::Handler<T>::Read(this, out_data);
}

可以看到调用了Handler::Read函数,继续跟进去看Handler的实现,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h第258行:

template<typename T>
struct Handler {
  ......
  /*!
   * \brief read data to stream
   * \param strm the stream to read the data.
   * \param data the pointer to the data obeject to read
   * \return whether the read is successful
   */
  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<dmlc::is_arithmetic<T>::value,
               ArithmeticHandler<T>,
               IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                          NativePODHandler<T>,
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

一开始看到这串代码可能会有点懵,不过没关系,接下来我们就一步步拆解这段代码,首先看IfThenElse结构体的实现,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h第48行:

template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;

template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
  ......
  inline static bool Read(Stream *strm, T *data) {
    return Then::Read(strm, data);
  }
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
  ......
  inline static bool Read(Stream *strm, T *data) {
    return Else::Read(strm, data);
  }
};

我理解就是根据模板参数在编译期间做分支选择,根据模板参数决定调用实现分支,这里可以看到如果第一个模板参数template<bool cond, ...>true的话,就调用 Then::Read函数 ,否则调用Else::Read函数,然后看回Handler::Read函数:

  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<dmlc::is_arithmetic<T>::value,
               ArithmeticHandler<T>,
               IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                          NativePODHandler<T>,
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

代码就很好理解了,如果dmlc::is_arithmetic<T>::value值为true则走ArithmeticHandler<T>,否则再进行第二次判断,先来看下具体实现${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第66行:

template<typename T>
struct is_arithmetic {
  static const bool value = std::is_arithmetic<T>::value;
};

为了可读性我简化了下代码,只要查下C++的文档看下std::is_arithmetic<T>的定义就知道什么模板参数类型是什么的情况下值是true或者false,C++文档解释见https://en.cppreference.com/w/cpp/types/is_arithmetic

也就是如果模板类型T是整数型或者浮点型,value的值就是true否则是false。如果满足条件则会调用ArithmeticHandler::Read函数,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h第82行:

/*! \brief Serializer for arithmetic data, handle endianness */
template<typename T>
struct ArithmeticHandler {
  ......
  inline static bool Read(Stream *strm, T *dptr) {
    bool ret = strm->Read((void*)dptr, sizeof(T)) == sizeof(T);  
    ......
    return ret;
  }
};

就是运行时调用子类重写的Read函数,从文件流中读取一个T类型元素。接着再看回其他分支选择:

  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<dmlc::is_arithmetic<T>::value,
               ArithmeticHandler<T>,
               IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                          NativePODHandler<T>,
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

如果不满足ArithmeticHandler的条件,则看下面一个判断dmlc::is_pod<T>,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第20行:

template<typename T>
struct is_pod {
#if DMLC_USE_CXX11
  /*! \brief the value of the traits */
  static const bool value = std::is_pod<T>::value;
#else
  /*! \brief the value of the traits */
  static const bool value = false;
#endif
};

看C++文档解释https://en.cppreference.com/w/cpp/types/is_pod

如果是plain old data type值就是true,关于PODType的解释大家可以参考:

https://zhuanlan.zhihu.com/p/29734547

https://en.cppreference.com/w/cpp/named\_req/PODType

NativePODHandler::Read函数的实现也是和ArithmeticHandler::Read类似,也是运行时调用Stream子类重写的Read函数,从文件流中读取一个T类型元素

template<typename T>
struct NativePODHandler {
  ......
  inline static bool Read(Stream *strm, T *dptr) {
    return strm->Read((void*)dptr, sizeof(T)) == sizeof(T);  // NOLINT(*)
  }
};

接着继续回看下一个条件判断dmlc::has_saveload<T>

  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<......
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第109行:

template<typename T>
struct has_saveload {
  /*! \brief the value of the traits */
  static const bool value = false;
};

默认值是false,不过看到${MXNET_ROOT}/include/mxnet/ndarray.h第1492行:

namespace dmlc {
/*!\brief traits */
DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true);
}  // namespace dmlc

${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第125行,DMLC_DECLARE_TRAITS的宏定义:

/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)       \
  template<>                                          \
  struct Trait<Type> {                                \
    static const bool value = Value;                  \
  }

就知道对于NDArray类来说,dmlc::has_saveload<NDArray>::value == true,所以可以判断如果模板参数类型是NDArray则会进入SaveLoadClassHandler实现:

template<typename T>
struct SaveLoadClassHandler {
  ......
  inline static bool Read(Stream *strm, T *data) {
    return data->Load(strm);
  }
};

实际就是调用了T::Load函数,也就是NDArray::Load成员函数。

Handler类还提供了其他模板参数类型的支持比如vector<T>或者std::string

template<typename T>
struct Handler<std::vector<T> > {
  ......
  inline static bool Read(Stream *strm, std::vector<T> *data) {
    return IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                      NativePODVectorHandler<T>,
                      ComposeVectorHandler<T>,
                      std::vector<T> >
    ::Read(strm, data);
  }
};

template<typename T>
struct Handler<std::basic_string<T> > {
  .....
  inline static bool Read(Stream *strm, std::basic_string<T> *data) {
    return IfThenElse<dmlc::is_pod<T>::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1),
                      NativePODStringHandler<T>,
                      UndefinedSerializerFor<T>,
                      std::basic_string<T> >
    ::Read(strm, data);
  }
};

相信有了前面Handler类的解释,理解这两个模板类的实现也就容易多了,这里就不展开了有兴趣的读者可以继续去深入了解,下面回来继续看参数NDArray加载参数逻辑部分的代码解读。

MXNet参数文件解析逻辑

首先给出MXNet参数文件存储内容的格式示意图:

然后根据官方代码的解析逻辑,我自己实现的参数提取代码,为了可读性简化了代码,完整代码见文章开头的github链接:

struct cpu {
  static const int kDevMask = 1 << 0;
};
struct gpu {
  static const int kDevMask = 1 << 1;
};

enum DeviceType {
  kCPU = cpu::kDevMask,
  kGPU = gpu::kDevMask,
  kCPUPinned = 3,
  kCPUShared = 5,
};

static bool Read(std::FILE *fp, void *ptr, size_t size) {
  return std::fread(ptr, 1, size, fp) == size;
} 

int32_t loadNDArrayV2(std::vector<NDArray *>& ndarrays, std::string param_file) {

  std::FILE *fp = fopen(param_file.c_str(), "rb");

  uint64_t header, reserved;
  Read(fp, (void*)(&header), sizeof(uint64_t));
  Read(fp, (void*)(&reserved), sizeof(uint64_t))

  uint64_t nd_size;
  Read(fp, (void*)(&nd_size), sizeof(uint64_t));

  size_t size = static_cast<size_t>(nd_size);
  ndarrays.resize(size);

  // read nd data
  for (size_t i = 0; i < nd_size; ++i) {
    NDArray* nd = new NDArray;
    ndarrays[i] = nd;

    uint32_t magic;
    Read(fp, (void*)(&magic), sizeof(uint32_t));

    // load storage type
    int32_t stype;
    Read(fp, (void*)(&stype), sizeof(int32_t));

    // load shape
    uint32_t ndim_{0};
    Read(fp, (void*)(&ndim_), sizeof(uint32_t));

    size_t nread = sizeof(int64_t) * ndim_;
    int64_t *data_heap_ = new int64_t[ndim_];
    Read(fp, (void*)data_heap_, nread);

    int64_t size = 1;
    for (uint32_t i=0; i<ndim_;++i) {
      size *= data_heap_[i];
      nd->shape.push_back(data_heap_[i]);
    }

    delete[] data_heap_;

    // load context 
    DeviceType dev_type;
    int32_t dev_id;
    Read(fp, (void*)(&dev_type), sizeof(dev_type));

    Read(fp, (void*)(&dev_id), sizeof(int32_t));

    // load type flag
    int32_t type_flag;
    Read(fp, (void*)(&type_flag), sizeof(int32_t));

    size_t all_size = size * mshadow_sizeof(type_flag);
    nd->numOfBytes = all_size;
    nd->data = (void *)malloc(all_size);

    Read(fp, nd->data, nd->numOfBytes);
  }

  // read nd names
  std::vector<std::string> keys;
  uint64_t keysLen;
  Read(fp, (void*)(&keysLen), sizeof(uint64_t));
  keys.resize(keysLen);

  for (uint64_t k = 0; k < keysLen; ++k) {
     uint64_t stringLen;
     Read(fp, (void*)(&stringLen), sizeof(uint64_t));
    size_t size = static_cast<size_t>(stringLen);
    keys[k].resize(size);
    if (size != 0) {
      size_t nbytes = sizeof(char) * size;
      Read(fp, (void*)(&(keys[k][0])), nbytes);
    }
  }
  std::fclose(fp);
  return kSuccess;
}

结合示意图和代码,应该就能比较好的理解参数文件的存储格式和解析方法了。



推荐文章


更多AI移动端优化的请关注专栏嵌入式AI以及知乎(@梁德澎)。
推荐阅读
关注数
16750
内容数
1233
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息