Open in Colab Open On GitHub

快速開始#

歡迎使用 Flax!

Flax 是在 JAX 基礎上建置的開源 Python 神經網路程式庫。本教學說明如何使用 Flax 的 Linen API 建構一個簡單的卷積神經網路 (CNN),並訓練該網路對 MNIST 資料集影像進行分類。

1. 安裝 Flax#

!pip install -q flax>=0.7.5

2. 載入資料#

Flax 可以使用任何資料載入管道,本範例說明如何使用 TFDS。定義一個載入和準備 MNIST 資料集並將範例轉換成浮點數的函式。

import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf             # TensorFlow operations

def get_datasets(num_epochs, batch_size):
  """Load MNIST train and test datasets into memory."""
  train_ds = tfds.load('mnist', split='train')
  test_ds = tfds.load('mnist', split='test')

  train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'],
                                                           tf.float32) / 255.,
                                          'label': sample['label']}) # normalize train set
  test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'],
                                                         tf.float32) / 255.,
                                        'label': sample['label']}) # normalize test set

  train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
  train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
  test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
  test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency

  return train_ds, test_ds

3. 定義網路#

透過繼承 Flax Module,使用 Linen API 建立一個卷積神經網路。由於本範例中的架構相對簡單(僅堆疊層),因此你可以直接在 __call__ 方法中定義內嵌的子模組,並使用 @compact 裝飾符來包覆它。如需進一步瞭解 Flax Linen @compact 裝飾符的資訊,請參閱 setupcompact 指南。

from flax import linen as nn  # Linen API

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

檢視模型層#

建立 Flax Module 的執行個體,並使用 Module.tabulate 方法來透過傳遞 RNG 金鑰和範本圖片輸入的方式,視覺化模型層的表格。

import jax
import jax.numpy as jnp  # JAX NumPy

cnn = CNN()
print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)),
                   compute_flops=True, compute_vjp_flops=True))
                                  CNN Summary                                   
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ path     module  inputs      outputs    flops    vjp_flops  params     ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩
│         │ CNN    │ float32[1… │ float32[… │ 8708106 │ 26957556  │            │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Conv_0  │ Conv   │ float32[1… │ float32[… │ 455424  │ 1341472   │ bias:      │
│         │        │            │           │         │           │ float32[3… │
│         │        │            │           │         │           │ kernel:    │
│         │        │            │           │         │           │ float32[3… │
│         │        │            │           │         │           │            │
│         │        │            │           │         │           │ 320 (1.3   │
│         │        │            │           │         │           │ KB)        │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Conv_1  │ Conv   │ float32[1… │ float32[… │ 6566144 │ 19704320  │ bias:      │
│         │        │            │           │         │           │ float32[6… │
│         │        │            │           │         │           │ kernel:    │
│         │        │            │           │         │           │ float32[3… │
│         │        │            │           │         │           │            │
│         │        │            │           │         │           │ 18,496     │
│         │        │            │           │         │           │ (74.0 KB)  │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Dense_0 │ Dense  │ float32[1… │ float32[… │ 1605888 │ 5620224   │ bias:      │
│         │        │            │           │         │           │ float32[2… │
│         │        │            │           │         │           │ kernel:    │
│         │        │            │           │         │           │ float32[3… │
│         │        │            │           │         │           │            │
│         │        │            │           │         │           │ 803,072    │
│         │        │            │           │         │           │ (3.2 MB)   │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Dense_1 │ Dense  │ float32[1… │ float32[… │ 5130    │ 17940     │ bias:      │
│         │        │            │           │         │           │ float32[1… │
│         │        │            │           │         │           │ kernel:    │
│         │        │            │           │         │           │ float32[2… │
│         │        │            │           │         │           │            │
│         │        │            │           │         │           │ 2,570      │
│         │        │            │           │         │           │ (10.3 KB)  │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│                                                      Total  824,458    │
│                                                             (3.3 MB)   │
└─────────┴────────┴────────────┴───────────┴─────────┴───────────┴────────────┘
                                                                                
                       Total Parameters: 824,458 (3.3 MB)                       

4. 建立 TrainState#

在 Flax 中,常見的模式是建立一個單一 dataclass,以表示整個訓練狀態,包括步驟數、參數和最佳化器狀態。

由於這是一種相當常見的模式,因此 Flax 提供了 flax.training.train_state.TrainState 類別,它服務於大部分的基本用例。

!pip install -q clu
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax                           # Common loss functions and optimizers

我們將使用 clu 函式庫進行度量計算。更多關於 clu 的資訊,請參閱 reponotebook

@struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output('loss')

然後,你可以建立 train_state.TrainState 的子類別,使其同時包含度量資訊。這樣的好處是,我們只需要將一個單一參數傳遞給函式(如 train_step(),見下方)即可同時計算損失值、更新參數和計算度量資訊。

class TrainState(train_state.TrainState):
  metrics: Metrics

def create_train_state(module, rng, learning_rate, momentum):
  """Creates an initial `TrainState`."""
  params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
  tx = optax.sgd(learning_rate, momentum)
  return TrainState.create(
      apply_fn=module.apply, params=params, tx=tx,
      metrics=Metrics.empty())

5. 訓練步驟#

一個函式可

使用 JAX 的 @jit 裝飾器追蹤整個 train_step 函式,並使用 XLA 即時編譯成融合裝置操作,以便在硬體加速器上更快、更有效率地執行。

@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
    return loss
  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

6. 度量計算#

為損失值和準確度度量建立一個個別函式。損失值使用 optax.softmax_cross_entropy_with_integer_labels 函式進行計算,而準確度則使用 clu.metrics 進行計算。

@jax.jit
def compute_metrics(*, state, batch):
  logits = state.apply_fn({'params': state.params}, batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
  metric_updates = state.metrics.single_from_model_output(
    logits=logits, labels=batch['label'], loss=loss)
  metrics = state.metrics.merge(metric_updates)
  state = state.replace(metrics=metrics)
  return state

7. 下載資料#

num_epochs = 10
batch_size = 32

train_ds, test_ds = get_datasets(num_epochs, batch_size)

8. 設定隨機種子#

  • 設定 TF 隨機種子以確保資料集洗牌(使用 tf.data.Dataset.shuffle)是可以複製的。

  • 取得一個 PRNGKey 並將其用於參數初始化。(進一步了解 JAX PRNG 設計PRNG 鏈。)

tf.random.set_seed(0)
init_rng = jax.random.key(0)

9. 初始化 TrainState#

請記住,函式 create_train_state 會初始化模型參數、最佳化器和指標,並將其放入傳回的訓練狀態資料類別中。

learning_rate = 0.01
momentum = 0.9
state = create_train_state(cnn, init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

10. 訓練與評量#

透過下列步驟建立「洗牌」的資料集:

  • 反覆播放資料集,等於訓練歷代的數量

  • 配置大小為 1024 的緩衝區(包含資料集中的前 1024 個樣本),用於隨機抽樣批次

    • 每次從緩衝區隨機繪製樣本時,會將資料集中的下一個樣本載入緩衝區

定義一個訓練迴圈,其:

  • 從資料集中隨機抽樣批次。

  • 對每個訓練批次執行最佳化步驟。

  • 計算一個世代中每個批次上的平均訓練指標。

  • 使用更新的參數計算測試集的指標。

  • 記錄訓練和測試指標以進行視覺化。

在完成 10 個世代的訓練和測試後,輸出應顯示您的模型能夠達到約 99% 的準確率。

# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs
metrics_history = {'train_loss': [],
                   'train_accuracy': [],
                   'test_loss': [],
                   'test_accuracy': []}
for step,batch in enumerate(train_ds.as_numpy_iterator()):

  # Run optimization steps over training batches and compute batch metrics
  state = train_step(state, batch) # get updated train state (which contains the updated parameters)
  state = compute_metrics(state=state, batch=batch) # aggregate batch metrics

  if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed
    for metric,value in state.metrics.compute().items(): # compute metrics
      metrics_history[f'train_{metric}'].append(value) # record metrics
    state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch

    # Compute metrics on the test set after each training epoch
    test_state = state
    for test_batch in test_ds.as_numpy_iterator():
      test_state = compute_metrics(state=test_state, batch=test_batch)

    for metric,value in test_state.metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)

    print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['train_loss'][-1]}, "
          f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
    print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['test_loss'][-1]}, "
          f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")
train epoch: 1, loss: 0.20290373265743256, accuracy: 93.87000274658203
test epoch: 1, loss: 0.07591685652732849, accuracy: 97.60617065429688
train epoch: 2, loss: 0.05760224163532257, accuracy: 98.28500366210938
test epoch: 2, loss: 0.050395529717206955, accuracy: 98.3974380493164
train epoch: 3, loss: 0.03897436335682869, accuracy: 98.83000183105469
test epoch: 3, loss: 0.04574578255414963, accuracy: 98.54767608642578
train epoch: 4, loss: 0.028721099719405174, accuracy: 99.15166473388672
test epoch: 4, loss: 0.035722777247428894, accuracy: 98.91827392578125
train epoch: 5, loss: 0.021948494017124176, accuracy: 99.37999725341797
test epoch: 5, loss: 0.035723842680454254, accuracy: 98.87820434570312
train epoch: 6, loss: 0.01705147698521614, accuracy: 99.54833221435547
test epoch: 6, loss: 0.03456473350524902, accuracy: 98.96835327148438
train epoch: 7, loss: 0.014007646590471268, accuracy: 99.6116714477539
test epoch: 7, loss: 0.04089202359318733, accuracy: 98.7880630493164
train epoch: 8, loss: 0.011265480890870094, accuracy: 99.73333740234375
test epoch: 8, loss: 0.03337760642170906, accuracy: 98.93830108642578
train epoch: 9, loss: 0.00918484665453434, accuracy: 99.78334045410156
test epoch: 9, loss: 0.034478139132261276, accuracy: 98.96835327148438
train epoch: 10, loss: 0.007260234095156193, accuracy: 99.84166717529297
test epoch: 10, loss: 0.032822880893945694, accuracy: 99.07852172851562

11. 視覺化指標#

import matplotlib.pyplot as plt  # Visualization

# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train','test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
plt.clf()
_images/281863de2467b2eb19baff419ebec86e88a215a30b1a1691ebea188c999c09f8.png
<Figure size 600x400 with 0 Axes>

12. 在測試集上執行推論#

定義一個 JIT 化的推論函式 pred_step。使用已學習的參數對測試集執行模型推論,並視覺化影像及其對應的預測標籤。

@jax.jit
def pred_step(state, batch):
  logits = state.apply_fn({'params': state.params}, test_batch['image'])
  return logits.argmax(axis=1)

test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(state, test_batch)
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
    ax.set_title(f"label={pred[i]}")
    ax.axis('off')
_images/c6bd7e9d04a64d28db87cb3764a6cfd03dd7c476ae7a1bd7650e427e6a3632ce.png

恭喜!您已完成附註的 MNIST 範例。您可以在 Flax 的 Git 回存庫中,以不同結構再瀏覽同一範例,但可以是幾個 Python 模組、測試模組、設定檔、另一個 Colab 和文件。

google/flax