PyTorchのC++ APIで自作のデータセットの定義の仕方を紹介. 今回は以下のようなデータセットを作る,
$ tree test_data/ test_data/ |-- images | |-- image1.jpg | |-- image2.jpg | `-- image3.jpg `-- labels.txt
- ラベルを定義したファイル:
labelfile
- ラベルは以下のように定義
$ cat test_data/labels.txt image1.jpg,0 image2.jpg,0 image3.jpg,1
PyTorch C++ APIによるデータセットの実装
torch::data::datasets::Dataset
を継承して実装をおこなう- 実装しなければいけない関数で絶対に必要なのは以下の2つ
torch::data::Example<> get(size_t index)
at::optional<size_t> size() const
#include <dirent.h> #include <iostream> #include <map> #include <sstream> #include <string> #include <tuple> #include <vector> #include <opencv2/core/core.hpp> #include <opencv2/highgui/highgui.hpp> #include <torch/torch.h> #include <torch/data/datasets/base.h> #include <torch/data/example.h> #include <torch/types.h> // Split label info. std::tuple<std::string, std::string> split(std::string &s, char delim) { std::stringstream ss(s); std::vector<std::string> elems; std::string item; while (getline(ss, item, delim)) { if (!item.empty()) { elems.push_back(item); } } return std::forward_as_tuple(elems[0], elems[1]); } // Define own dataset class ImageDataset : public torch::data::datasets::Dataset<ImageDataset> { private: std::string root; std::vector<std::string> files; std::map<std::string, int> labels; public: explicit ImageDataset(const std::string root, const std::string labelfile) : root(root) { // get files auto p = opendir(root.c_str()); dirent* entry; if(p != nullptr) { do { entry = readdir(p); if(entry != nullptr) { if(strcmp(entry->d_name, ".\0") == 0 || strcmp(entry->d_name, "..\0") == 0) continue; files.push_back(entry->d_name); } } while(entry != nullptr); } // get labels std::ifstream fs(labelfile); std::string buf; std::string fname, label; while(fs >> buf) { std::tie(fname, label) = split(buf, ','); labels[fname] = stoi(label); } } torch::data::Example<> get(size_t index) override { std::string fname = this->root + this->files[index]; std::cout << fname << std::endl; int label = this->labels.at(files[index]); cv::Mat image = cv::imread(fname, 1); std::vector<int64_t> sizes = {1, 3, image.rows, image.cols}; at::Tensor tensor_image = torch::from_blob(image.data, at::IntList(sizes), at::ScalarType::Byte); at::Tensor tensor_label = torch::tensor({label}, torch::dtype(torch::kUInt8)); tensor_image = tensor_image.toType(at::kFloat); return {tensor_image, tensor_label}; } at::optional<size_t> size() const override { return this->files.size(); } }; int main(int argc, char **argv) { std::string root = argv[1]; std::string labelfile = argv[2]; ImageDataset dataset(root, labelfile); auto batch = dataset.get(0); std::cout << "input dim: " << batch.data.dim() << std::endl; std::cout << "target: " << batch.target << std::endl; return 0; }
実行結果
$ ./dataset ../test_data/ ../test_data/labels.txt input dim: 4 target: 1 [ Variable[CPUByteType]{1} ]
- C++ Frontendの情報はほとんど出回ってないのでサンプルコードは積極的に上げていきたい
- ところで私もまだまだ書き慣れていないため名前空間が
at
とtorch
で揺れてたりする - 公式にプルリク投げた (https://github.com/pytorch/examples/pull/506)