快速開始#
歡迎使用 Flax!
Flax 是在 JAX 基礎上建置的開源 Python 神經網路程式庫。本教學說明如何使用 Flax 的 Linen API 建構一個簡單的卷積神經網路 (CNN),並訓練該網路對 MNIST 資料集影像進行分類。
1. 安裝 Flax#
!pip install -q flax>=0.7.5
2. 載入資料#
Flax 可以使用任何資料載入管道,本範例說明如何使用 TFDS。定義一個載入和準備 MNIST 資料集並將範例轉換成浮點數的函式。
import tensorflow_datasets as tfds # TFDS for MNIST
import tensorflow as tf # TensorFlow operations
def get_datasets(num_epochs, batch_size):
"""Load MNIST train and test datasets into memory."""
train_ds = tfds.load('mnist', split='train')
test_ds = tfds.load('mnist', split='test')
train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'],
tf.float32) / 255.,
'label': sample['label']}) # normalize train set
test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'],
tf.float32) / 255.,
'label': sample['label']}) # normalize test set
train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
return train_ds, test_ds
3. 定義網路#
透過繼承 Flax Module,使用 Linen API 建立一個卷積神經網路。由於本範例中的架構相對簡單(僅堆疊層),因此你可以直接在 __call__
方法中定義內嵌的子模組,並使用 @compact
裝飾符來包覆它。如需進一步瞭解 Flax Linen @compact
裝飾符的資訊,請參閱 setup
與 compact
指南。
from flax import linen as nn # Linen API
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
檢視模型層#
建立 Flax Module 的執行個體,並使用 Module.tabulate
方法來透過傳遞 RNG 金鑰和範本圖片輸入的方式,視覺化模型層的表格。
import jax
import jax.numpy as jnp # JAX NumPy
cnn = CNN()
print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)),
compute_flops=True, compute_vjp_flops=True))
CNN Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ │ CNN │ float32[1… │ float32[… │ 8708106 │ 26957556 │ │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Conv_0 │ Conv │ float32[1… │ float32[… │ 455424 │ 1341472 │ bias: │
│ │ │ │ │ │ │ float32[3… │
│ │ │ │ │ │ │ kernel: │
│ │ │ │ │ │ │ float32[3… │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 320 (1.3 │
│ │ │ │ │ │ │ KB) │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Conv_1 │ Conv │ float32[1… │ float32[… │ 6566144 │ 19704320 │ bias: │
│ │ │ │ │ │ │ float32[6… │
│ │ │ │ │ │ │ kernel: │
│ │ │ │ │ │ │ float32[3… │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 18,496 │
│ │ │ │ │ │ │ (74.0 KB) │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Dense_0 │ Dense │ float32[1… │ float32[… │ 1605888 │ 5620224 │ bias: │
│ │ │ │ │ │ │ float32[2… │
│ │ │ │ │ │ │ kernel: │
│ │ │ │ │ │ │ float32[3… │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 803,072 │
│ │ │ │ │ │ │ (3.2 MB) │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ Dense_1 │ Dense │ float32[1… │ float32[… │ 5130 │ 17940 │ bias: │
│ │ │ │ │ │ │ float32[1… │
│ │ │ │ │ │ │ kernel: │
│ │ │ │ │ │ │ float32[2… │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 2,570 │
│ │ │ │ │ │ │ (10.3 KB) │
├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤
│ │ │ │ │ │ Total │ 824,458 │
│ │ │ │ │ │ │ (3.3 MB) │
└─────────┴────────┴────────────┴───────────┴─────────┴───────────┴────────────┘
Total Parameters: 824,458 (3.3 MB)
4. 建立 TrainState
#
在 Flax 中,常見的模式是建立一個單一 dataclass,以表示整個訓練狀態,包括步驟數、參數和最佳化器狀態。
由於這是一種相當常見的模式,因此 Flax 提供了 flax.training.train_state.TrainState
類別,它服務於大部分的基本用例。
!pip install -q clu
from clu import metrics
from flax.training import train_state # Useful dataclass to keep train state
from flax import struct # Flax dataclasses
import optax # Common loss functions and optimizers
我們將使用 clu
函式庫進行度量計算。更多關於 clu
的資訊,請參閱 repo 和 notebook。
@struct.dataclass
class Metrics(metrics.Collection):
accuracy: metrics.Accuracy
loss: metrics.Average.from_output('loss')
然後,你可以建立 train_state.TrainState
的子類別,使其同時包含度量資訊。這樣的好處是,我們只需要將一個單一參數傳遞給函式(如 train_step()
,見下方)即可同時計算損失值、更新參數和計算度量資訊。
class TrainState(train_state.TrainState):
metrics: Metrics
def create_train_state(module, rng, learning_rate, momentum):
"""Creates an initial `TrainState`."""
params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
tx = optax.sgd(learning_rate, momentum)
return TrainState.create(
apply_fn=module.apply, params=params, tx=tx,
metrics=Metrics.empty())
5. 訓練步驟#
一個函式可
對照給定的參數和一批輸入圖像評估神經網路,使用
TrainState.apply_fn
(包含Module.apply
方法(正向遞送)。使用預先定義的
optax.softmax_cross_entropy_with_integer_labels()
計算交叉熵損失值。請注意,此函式預期整數量標籤,因此無需將標籤轉換為 one-hot 編碼。使用
jax.grad
評估損失函式的梯度。將 pytree 的梯度套用至最佳化器,以更新模型參數。
使用 JAX 的 @jit 裝飾器追蹤整個 train_step
函式,並使用 XLA 即時編譯成融合裝置操作,以便在硬體加速器上更快、更有效率地執行。
@jax.jit
def train_step(state, batch):
"""Train for a single step."""
def loss_fn(params):
logits = state.apply_fn({'params': params}, batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
return loss
grad_fn = jax.grad(loss_fn)
grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
6. 度量計算#
為損失值和準確度度量建立一個個別函式。損失值使用 optax.softmax_cross_entropy_with_integer_labels
函式進行計算,而準確度則使用 clu.metrics
進行計算。
@jax.jit
def compute_metrics(*, state, batch):
logits = state.apply_fn({'params': state.params}, batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
metric_updates = state.metrics.single_from_model_output(
logits=logits, labels=batch['label'], loss=loss)
metrics = state.metrics.merge(metric_updates)
state = state.replace(metrics=metrics)
return state
7. 下載資料#
num_epochs = 10
batch_size = 32
train_ds, test_ds = get_datasets(num_epochs, batch_size)
8. 設定隨機種子#
設定 TF 隨機種子以確保資料集洗牌(使用
tf.data.Dataset.shuffle
)是可以複製的。取得一個 PRNGKey 並將其用於參數初始化。(進一步了解 JAX PRNG 設計 和 PRNG 鏈。)
tf.random.set_seed(0)
init_rng = jax.random.key(0)
9. 初始化 TrainState
#
請記住,函式 create_train_state
會初始化模型參數、最佳化器和指標,並將其放入傳回的訓練狀態資料類別中。
learning_rate = 0.01
momentum = 0.9
state = create_train_state(cnn, init_rng, learning_rate, momentum)
del init_rng # Must not be used anymore.
10. 訓練與評量#
透過下列步驟建立「洗牌」的資料集:
反覆播放資料集,等於訓練歷代的數量
配置大小為 1024 的緩衝區(包含資料集中的前 1024 個樣本),用於隨機抽樣批次
每次從緩衝區隨機繪製樣本時,會將資料集中的下一個樣本載入緩衝區
定義一個訓練迴圈,其:
從資料集中隨機抽樣批次。
對每個訓練批次執行最佳化步驟。
計算一個世代中每個批次上的平均訓練指標。
使用更新的參數計算測試集的指標。
記錄訓練和測試指標以進行視覺化。
在完成 10 個世代的訓練和測試後,輸出應顯示您的模型能夠達到約 99% 的準確率。
# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs
metrics_history = {'train_loss': [],
'train_accuracy': [],
'test_loss': [],
'test_accuracy': []}
for step,batch in enumerate(train_ds.as_numpy_iterator()):
# Run optimization steps over training batches and compute batch metrics
state = train_step(state, batch) # get updated train state (which contains the updated parameters)
state = compute_metrics(state=state, batch=batch) # aggregate batch metrics
if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed
for metric,value in state.metrics.compute().items(): # compute metrics
metrics_history[f'train_{metric}'].append(value) # record metrics
state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch
# Compute metrics on the test set after each training epoch
test_state = state
for test_batch in test_ds.as_numpy_iterator():
test_state = compute_metrics(state=test_state, batch=test_batch)
for metric,value in test_state.metrics.compute().items():
metrics_history[f'test_{metric}'].append(value)
print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
f"loss: {metrics_history['train_loss'][-1]}, "
f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
f"loss: {metrics_history['test_loss'][-1]}, "
f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")
train epoch: 1, loss: 0.20290373265743256, accuracy: 93.87000274658203
test epoch: 1, loss: 0.07591685652732849, accuracy: 97.60617065429688
train epoch: 2, loss: 0.05760224163532257, accuracy: 98.28500366210938
test epoch: 2, loss: 0.050395529717206955, accuracy: 98.3974380493164
train epoch: 3, loss: 0.03897436335682869, accuracy: 98.83000183105469
test epoch: 3, loss: 0.04574578255414963, accuracy: 98.54767608642578
train epoch: 4, loss: 0.028721099719405174, accuracy: 99.15166473388672
test epoch: 4, loss: 0.035722777247428894, accuracy: 98.91827392578125
train epoch: 5, loss: 0.021948494017124176, accuracy: 99.37999725341797
test epoch: 5, loss: 0.035723842680454254, accuracy: 98.87820434570312
train epoch: 6, loss: 0.01705147698521614, accuracy: 99.54833221435547
test epoch: 6, loss: 0.03456473350524902, accuracy: 98.96835327148438
train epoch: 7, loss: 0.014007646590471268, accuracy: 99.6116714477539
test epoch: 7, loss: 0.04089202359318733, accuracy: 98.7880630493164
train epoch: 8, loss: 0.011265480890870094, accuracy: 99.73333740234375
test epoch: 8, loss: 0.03337760642170906, accuracy: 98.93830108642578
train epoch: 9, loss: 0.00918484665453434, accuracy: 99.78334045410156
test epoch: 9, loss: 0.034478139132261276, accuracy: 98.96835327148438
train epoch: 10, loss: 0.007260234095156193, accuracy: 99.84166717529297
test epoch: 10, loss: 0.032822880893945694, accuracy: 99.07852172851562
11. 視覺化指標#
import matplotlib.pyplot as plt # Visualization
# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train','test'):
ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
plt.clf()

<Figure size 600x400 with 0 Axes>
12. 在測試集上執行推論#
定義一個 JIT 化的推論函式 pred_step
。使用已學習的參數對測試集執行模型推論,並視覺化影像及其對應的預測標籤。
@jax.jit
def pred_step(state, batch):
logits = state.apply_fn({'params': state.params}, test_batch['image'])
return logits.argmax(axis=1)
test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(state, test_batch)
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
ax.set_title(f"label={pred[i]}")
ax.axis('off')

恭喜!您已完成附註的 MNIST 範例。您可以在 Flax 的 Git 回存庫中,以不同結構再瀏覽同一範例,但可以是幾個 Python 模組、測試模組、設定檔、另一個 Colab 和文件。