Loading datasets#
Neural networks written with JAX + Flax expect their input data as jax.numpy array instances. Therefore, loading a dataset from any source is as simple as converting it to jax.numpy types and reshaping it to the dimensions appropriate for the network.
This example demonstrates how to import MNIST using the APIs of Torchvision, TensorFlow Datasets, and Hugging Face. We load the whole dataset into memory. For datasets that do not fit in memory, the process is analogous but should be done in a batch-wise fashion (a minimal batch-iteration sketch appears at the end of this guide).
The MNIST dataset consists of 28x28 pixel grayscale images of handwritten digits and comes with a fixed 60k/10k train/test split. The task is to predict the correct class (digit 0, …, 9) for each image.
Assuming a CNN-based classifier, the input data should have shape (B, 28, 28, 1), where the trailing singleton dimension denotes the grayscale image channel.
Labels are simply integers denoting the digit corresponding to an image. Labels should therefore have shape (B,) so that the loss can be computed with optax.softmax_cross_entropy_with_integer_labels.
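For example, with arrays of those shapes, the loss for a batch can be computed as in the following minimal sketch. Here `apply_fn` and `params` stand in for whatever Flax model you use and are purely illustrative; the loaders below do not define them.

import optax

def loss_fn(params, apply_fn, images, labels):
    # images: (B, 28, 28, 1) float32, labels: (B,) integers
    logits = apply_fn(params, images)  # assumed to return (B, 10) logits
    # per-example cross-entropy against integer class labels, averaged over the batch
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()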
import numpy as np
import jax.numpy as jnp
Loading from torchvision.datasets#
import torchvision
def get_dataset_torch():
    mnist = {
        'train': torchvision.datasets.MNIST('./data', train=True, download=True),
        'test': torchvision.datasets.MNIST('./data', train=False, download=True)
    }

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': mnist[split].data.numpy(),
            'label': mnist[split].targets.numpy()
        }

        # cast from np to jnp and rescale the pixel values from [0,255] to [0,1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # torchvision returns shape (B, 28, 28).
        # hence, append the trailing channel dimension.
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_torch()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16
Loading from tensorflow_datasets#
import tensorflow_datasets as tfds
def get_dataset_tf():
    mnist = tfds.builder('mnist')
    mnist.download_and_prepare()

    ds = {}

    for split in ['train', 'test']:
        ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))

        # cast to jnp and rescale pixel values from [0,255] to [0,1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

    return ds['train'], ds['test']
train, test = get_dataset_tf()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16
Loading from 🤗 Hugging Face datasets#
#!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install
from datasets import load_dataset
def get_dataset_hf():
    mnist = load_dataset("mnist")

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': np.array([np.array(im) for im in mnist[split]['image']]),
            'label': np.array(mnist[split]['label'])
        }

        # cast to jnp and rescale pixel values from [0,255] to [0,1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # append the trailing channel dimension
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_hf()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16
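As noted at the beginning of this guide, datasets that do not fit in memory are processed batch-wise rather than loaded all at once. The sketch below iterates over the in-memory arrays returned by any of the loaders above in shuffled mini-batches; the function name, the batch size, and the use of numpy for shuffling are illustrative assumptions, not part of the loaders themselves.

import numpy as np

def iterate_minibatches(ds, batch_size, seed=0):
    # Yield {'image', 'label'} dicts of `batch_size` examples in a shuffled order.
    num_examples = ds['image'].shape[0]
    perm = np.random.default_rng(seed).permutation(num_examples)
    for start in range(0, num_examples, batch_size):
        idx = perm[start:start + batch_size]
        yield {'image': ds['image'][idx], 'label': ds['label'][idx]}

# e.g. using the `train` split loaded above:
# for batch in iterate_minibatches(train, batch_size=128):
#     ...  # run a training step on batch['image'], batch['label']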