
MNIST tutorial

Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) with the Flax NNX API to classify handwritten digits in the MNIST dataset.

Flax NNX is a Python neural network library built upon JAX. If you have used the Flax Linen API before, check out Why Flax NNX. You should have some knowledge of the main deep learning concepts.

Let's get started!

1. Install Flax

If flax is not installed in your Python environment, use pip to install the package from PyPI (uncomment the code in the cell below if you are working from Google Colab/Jupyter Notebook):

# !pip install flax

2. Load the MNIST dataset

First, you need to load the MNIST dataset and then prepare the training and test sets via Tensorflow Datasets (TFDS). You normalize the image values, shuffle the data and split it into batches, and prefetch samples to improve performance.

import tensorflow_datasets as tfds  # TFDS to download MNIST.
import tensorflow as tf  # TensorFlow / `tf.data` operations.

tf.random.set_seed(0)  # Set the random seed for reproducibility.

train_steps = 1200
eval_every = 200
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # Normalize the train set.
test_ds = test_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # Normalize the test set.

# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from.
train_ds = train_ds.repeat().shuffle(1024)
# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
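As a quick sanity check (an optional sketch, not part of the original pipeline), you can pull a single batch and confirm the shapes the pipeline yields:

# Optional: grab one batch and inspect its shapes.
batch = next(train_ds.as_numpy_iterator())
print(batch['image'].shape, batch['label'].shape)  # Expected: (32, 28, 28, 1) (32,)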

3. Define the model with Flax NNX

Create a CNN for classification with Flax NNX by subclassing nnx.Module:

from flax import nnx  # The Flax NNX API.
from functools import partial

class CNN(nnx.Module):
  """A simple CNN model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))
# Visualize it.
nnx.display(model)

Run the model

Let's put the CNN model to the test! Here, you'll perform a forward pass with arbitrary data and print the results.

import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)

4. Create the optimizer and define some metrics

In Flax NNX, you need to create an nnx.Optimizer object to manage the model's parameters and apply gradients during training. nnx.Optimizer receives the model's reference, so that it can update its parameters, and an Optax optimizer to define the update rules. Additionally, you will define an nnx.MultiMetric object to keep track of the Accuracy and the Average loss.

import optax

learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)

5. Define training step functions

In this section, you will define a loss function using the cross entropy loss (optax.softmax_cross_entropy_with_integer_labels()) that the CNN model will optimize over.

In addition to the loss, during training and testing you will also get the logits, which will be used to calculate the accuracy metric.

During training (train_step), you will use nnx.value_and_grad to compute the gradients and update the model's parameters using the optimizer you have already defined. During both training and testing (eval_step), the loss and logits will be used to calculate the metrics.

def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits

@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
  optimizer.update(grads)  # In-place updates.

@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.

In the code above, the nnx.jit transformation decorator traces the train_step function for just-in-time compilation with XLA, optimizing performance on hardware accelerators such as Google TPUs and GPUs. nnx.jit is a "lifted" version of the jax.jit transform that allows its function inputs and outputs to be Flax NNX objects. Similarly, nnx.value_and_grad is a lifted version of jax.value_and_grad. Check out the lifted transforms guide to learn more.

Note: The code shows how to perform several in-place updates to the model, the optimizer, and the metrics, but state updates were not explicitly returned. This is because Flax NNX transformations respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for concise and readable code. You can learn more in Why Flax NNX.
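To see these reference semantics in isolation, here is a minimal sketch (the Count variable and Counter module are made up for illustration and are not part of this tutorial; the flax.nnx and jax.numpy imports from above are reused). A variable mutated inside an nnx.jit-compiled function is updated in place on the original object, with no state explicitly returned:

class Count(nnx.Variable):  # A custom variable type holding mutable state.
  pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def increment(self):
    self.count.value += 1  # In-place update.

@nnx.jit
def bump(counter: Counter):
  counter.increment()  # Nothing is returned; the state update is propagated.

counter = Counter()
bump(counter)
bump(counter)
print(counter.count.value)  # 2 -- both updates reached the original object.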

6. Train and evaluate the model

Now, you can train the CNN model using batches of data for 1,200 training steps (train_steps), evaluate the model's performance on the test set every 200 steps (eval_every), and log the training and test metrics (the loss and the accuracy) along the way. As the output below shows, this typically gets the model to around 98% accuracy on the test set.

metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Run the optimization for one step and make a stateful update to the following:
  # - The train state's model parameters
  # - The optimizer state
  # - The training loss and accuracy batch metrics
  train_step(model, optimizer, metrics, batch)

  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # Evaluate every `eval_every` steps, and at the final step.
    # Log the training metrics.
    for metric, value in metrics.compute().items():  # Compute the metrics.
      metrics_history[f'train_{metric}'].append(value)  # Record the metrics.
    metrics.reset()  # Reset the metrics for the test set.

    # Compute the metrics on the test set after each evaluation interval.
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Log the test metrics.
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # Reset the metrics for the next evaluation interval.

    print(
      f"[train] step: {step}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"[test] step: {step}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )
[train] step: 200, loss: 0.3102289140224457, accuracy: 90.08084869384766
[test] step: 200, loss: 0.13239526748657227, accuracy: 95.52284240722656
[train] step: 400, loss: 0.12522409856319427, accuracy: 96.515625
[test] step: 400, loss: 0.07021520286798477, accuracy: 97.8465576171875
[train] step: 600, loss: 0.09092658758163452, accuracy: 97.25
[test] step: 600, loss: 0.08268354833126068, accuracy: 97.30569458007812
[train] step: 800, loss: 0.07523862272500992, accuracy: 97.921875
[test] step: 800, loss: 0.060881033539772034, accuracy: 98.036865234375
[train] step: 1000, loss: 0.063808374106884, accuracy: 98.09375
[test] step: 1000, loss: 0.07719086110591888, accuracy: 97.4258804321289
[train] step: 1199, loss: 0.07750937342643738, accuracy: 97.47173309326172
[test] step: 1199, loss: 0.05415954813361168, accuracy: 98.32732391357422

7. Visualize the metrics

With Matplotlib, you can create plots for the loss and the accuracy:

import matplotlib.pyplot as plt  # Visualization

# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
[Figure: "Loss" and "Accuracy" subplots showing the train and test curves over the evaluation intervals]

8. Perform inference on the test set

Create a jit-compiled model inference function (with nnx.jit) - pred_step - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.

model.eval() # Switch to evaluation mode.

@nnx.jit
def pred_step(model: CNN, batch):
  logits = model(batch['image'])
  return logits.argmax(axis=1)

Note that we use .eval() to make sure the model is in evaluation mode; even though we are not using Dropout or BatchNorm in this model, .eval() ensures that the outputs are deterministic.
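As an illustration of what those modes toggle, here is a minimal sketch with a stochastic layer (the Block module below is hypothetical, not the CNN from this tutorial): in training mode nnx.Dropout randomly masks activations, while after .eval() the same call is deterministic.

class Block(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

  def __call__(self, x):
    return self.dropout(self.linear(x))

block = Block(rngs=nnx.Rngs(0))
x = jnp.ones((1, 4))

block.train()  # Training mode: dropout is active, outputs vary per call.
y_stochastic = block(x)
block.eval()   # Evaluation mode: dropout is a no-op, outputs are deterministic.
y_deterministic = block(x)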

test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(model, test_batch)

fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
  ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
  ax.set_title(f'label={pred[i]}')
  ax.axis('off')
[Figure: a 5x5 grid of MNIST test images, each titled with its predicted label]

Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.

Next, check out Why Flax NNX? and get started with a series of Flax NNX guides.