TensorFlow推奨のデータ形式TFRecordを使う

概要

Tensorflow推奨のデータ形式.あらかじめ画像などをTFRecord形式で吐いておくとかなり効率的に学習・推論ができる. コンバータも簡単に書けるのでおすすめ.変換速度もまあまあ速い.

使い方

Convert Images to TFRecord

def convert(data_root):
  images, labels = get_images(data_root) # 画像とラベルのndarrayを返す関数
  fname = "dataset.tfrecords"
  with tf.python.io.TFRecordWriter(fname):
    for i in range(len(images)):
      image_raw = images[i].tostring()
      example = tf.train.Example(
              features=tf.train.Features(
                    feature={
                          "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]])),
                          "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])))
                    }))
      writer.write(example.SerializeToString())

Read TFRecord

features = tf.parse_single_example(
  serialized_example,
  features={
    "image_raw": tf.FixedLenFeature([], tf.string),
    "label": tf.FixedLenFeature([], tf.int64)})

image = tf.decode_raw(features["image_raw"], tf.uint8)