【tf.keras】tensorflow datasets,tfds

一些最經常使用的數據集如 MNIST、Fashion MNIST、cifar10/100 在 tf.keras.datasets 中就能找到,但對於其它也經常使用的數據集如 SVHN、Caltech101,tf.keras.datasets 中沒有,此時咱們能夠在 TensorFlow Datasets 中找找看。python

tensorflow_datasets 裏面包含的數據集列表:https://www.tensorflow.org/datasets/catalog/overview#all_datasetsgit

tensorflow_datasets 安裝:pip install tensorflow_datasetsgithub

tensorflow_datasets 示例:

獲得 tf.data.Dataset 對象:code

import tensorflow as tf
import tensorflow_datasets as tfds

data, info = tfds.load("mnist", with_info=True)
print(info)

train_data, test_data = data['train'], data['test']
assert isinstance(train_data, tf.data.Dataset)
print(train_data)

獲得 numpy.ndarray 對象:對象

import tensorflow_datasets as tfds
# `batch_size=-1`, will return the full dataset as `tf.Tensor`s.
dataset, info = tfds.load("mnist", batch_size=-1, with_info=True)
print(info)
train, test = dataset["train"], dataset["test"]
print(type(train['image']))

train = tfds.as_numpy(train)
print(type(train['image']))
print(train['image'].shape)
print(train['label'].shape)

tf.data.Dataset 進行簡單劃分驗證集能夠參考 https://github.com/tensorflow/datasets/issues/665#issuecomment-502409920ip

若是想對 MNIST 等數據集手動分層隨機劃分出一個驗證集,仍是轉化成 numpy.ndarray 比較方便,再使用 sklearn 的 train_test_split 方法一行代碼就能夠搞定。ci

References

https://www.tensorflow.org/datasets
https://www.tensorflow.org/datasets/catalog/overview#all_datasets
https://github.com/tensorflow/datasets/blob/master/docs/splits.mdget

相關文章
相關標籤/搜索