載入資料集#

Open in Colab

使用 Jax+Flax 編寫的神經網路會將輸入資料視為 jax.numpy 陣列實例。因此,從任何來源載入資料集只要將其轉換為 jax.numpy 類型,並將其重新塑造成網路的適當維度就十分簡單了。

以下這個範例示範如何利用 Torchvision、Tensorflow 和 Hugging Face 的 API 匯入 MNIST。我們會將整個資料集載入記憶體中。對於不符合記憶體大小的資料集,處理方式相似,但應以批次方式進行。

MNIST 資料集由 28x28 像素的灰階手寫數字影像組成,並指定為 60k/10k 訓練/測試分割。任務是預測每張影像的正確類別 (數字 0, …, 9)。

假設我們使用基於 CNN 的分類器,輸入資料的形狀應為 (B, 28, 28, 1),其中最後的單一維度表示灰階影像通道。

標籤只是表示與影像對應的數字的整數。因此,標籤應為形狀 (B,),才能使用 optax.softmax_cross_entropy_with_integer_labels 進行損失計算。

import numpy as np
import jax.numpy as jnp

torchvision.datasets 載入#

import torchvision
def get_dataset_torch():
    mnist = {
        'train': torchvision.datasets.MNIST('./data', train=True, download=True),
        'test': torchvision.datasets.MNIST('./data', train=False, download=True)
    }

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': mnist[split].data.numpy(),
            'label': mnist[split].targets.numpy()
        }

        # cast from np to jnp and rescale the pixel values from [0,255] to [0,1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # torchvision returns shape (B, 28, 28).
        # hence, append the trailing channel dimension.
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_torch()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16

tensorflow_datasets 載入#

import tensorflow_datasets as tfds
def get_dataset_tf():
    mnist = tfds.builder('mnist')
    mnist.download_and_prepare()

    ds = {}

    for split in ['train', 'test']:
        ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

    return ds['train'], ds['test']
train, test = get_dataset_tf()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16

從 🤗 Hugging Face datasets 載入#

#!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install
from datasets import load_dataset
def get_dataset_hf():
    mnist = load_dataset("mnist")

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': np.array([np.array(im) for im in mnist[split]['image']]),
            'label': np.array(mnist[split]['label'])
        }

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # append trailing channel dimension
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_hf()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16