Ensembling on multiple devices#
We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble equals the number of available devices. In a nutshell, the changes are:

- use jax.pmap() to parallelize the relevant functions,
- split the random seed to obtain different parameter initializations,
- replicate inputs and unreplicate outputs where necessary,
- average the probabilities across devices to compute predictions.

In this HOWTO we omit some of the code, such as the imports, the CNN module, and the metrics computation, but it can all be found in the MNIST example.
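The snippets below assume roughly the following imports. This is a minimal sketch; the authoritative list, together with the CNN module and get_datasets(), lives in the MNIST example.

# Minimal sketch of the imports assumed by the snippets below; the exact
# list, the CNN module, and get_datasets() are defined in the MNIST example.
import functools
import logging

import jax
import jax.numpy as jnp
import numpy as np
import optax

from flax import jax_utils
from flax.training import train_state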
Parallel functions#
We start by creating a parallel version of create_train_state(), which retrieves the initial parameters of the model. We do this using jax.pmap(). The effect of pmapping a function is that it compiles the function with XLA (similar to jax.jit()), but executes it in parallel on XLA devices such as GPUs and TPUs.
# Single-model version.
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
# Ensembling version: pmapped over the first axis of `rng`.
@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
Note that for the single-model code above, we use jax.jit() to lazily initialize the model (see the Module.init() documentation for more details). For the ensembling case, jax.pmap() maps over the first axis of the provided argument rng by default, so we should make sure to provide a different value for each device when we call this function later on.

Also note how we specify that learning_rate and momentum are static arguments, which means the concrete values of these arguments are used rather than abstract shapes. This is necessary because they will be passed as scalar values. For more details, see JIT mechanics: tracing and static variables.
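For illustration, calling the pmapped version later looks roughly like this (a sketch; the hyperparameter values 0.01 and 0.9 are placeholders, and the real call appears in the training code below):

# Sketch: one RNG key per device, scalar hyperparameters passed as static
# arguments. Each leaf of the returned train state then carries a leading
# axis of size jax.device_count(). 0.01 and 0.9 are placeholder values.
rngs = jax.random.split(jax.random.key(0), jax.device_count())
state = create_train_state(rngs, 0.01, 0.9)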
Next we do the same for the apply_model() and update_model() functions. To compute predictions from the ensemble, we take the mean of the individual probabilities. We use jax.lax.pmean() to compute that mean across devices, which also requires us to specify the axis_name for both jax.pmap() and jax.lax.pmean().
# Single-model version.
@jax.jit
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy

@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
# Ensembling version: probabilities are averaged across devices with pmean.
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads, loss, accuracy

@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
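To build some intuition for axis_name, here is a small self-contained sketch (not part of the example) in which every device contributes one value and jax.lax.pmean() replaces it with the mean over the named axis:

# Toy sketch of a named axis: pmap introduces the axis 'ensemble' and
# pmean averages over it, so every device ends up with the same value.
import functools
import jax
import jax.numpy as jnp

n = jax.local_device_count()

@functools.partial(jax.pmap, axis_name='ensemble')
def device_mean(x):
  return jax.lax.pmean(x, axis_name='ensemble')

print(device_mean(jnp.arange(n, dtype=jnp.float32)))  # n identical entries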
Training the ensemble#
Next we transform the train_epoch() function. The main thing to keep in mind when calling the pmapped functions from above is that we need to replicate their arguments across all devices and unreplicate their return values.
# Single-model version.
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
# Ensembling version: batches are replicated, per-device metrics are unreplicated.
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = jax_utils.replicate(train_ds['image'][perm, ...])
    batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(jax_utils.unreplicate(loss))
    epoch_accuracy.append(jax_utils.unreplicate(accuracy))
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
As you can see, we do not have to change any of the logic around state. This is because, as we will see in the training code below, the train state is already replicated, so passing it to the pmapped apply_model() and update_model() just works. The train dataset, however, is not replicated yet, so we do that here. Since replicating the entire train dataset would take up too much memory, we do it at the batch level.
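For reference, flax.jax_utils.replicate() stacks a copy of its input for every local device (adding a leading device axis), and flax.jax_utils.unreplicate() strips that axis again by returning the first copy. A quick sketch, with a placeholder batch shape:

# Sketch of replicate/unreplicate semantics; the batch shape is a placeholder.
import jax
import numpy as np
from flax import jax_utils

n = jax.local_device_count()
batch = np.zeros((32, 28, 28, 1))
replicated = jax_utils.replicate(batch)       # shape (n, 32, 28, 28, 1)
restored = jax_utils.unreplicate(replicated)  # shape (32, 28, 28, 1)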
We can now rewrite the actual training logic. This consists of two simple changes: making sure the RNG is split into one key per device when passed to create_train_state(), and replicating the test dataset (which is much smaller than the train dataset, so we can simply do this for the entire dataset).
# Single-model version.
train_ds, test_ds = get_datasets()
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate, momentum)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)
  _, test_loss, test_accuracy = apply_model(
      state, test_ds['image'], test_ds['label'])

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))
# Ensembling version: per-device RNGs and a replicated test dataset.
train_ds, test_ds = get_datasets()
test_ds = jax_utils.replicate(test_ds)
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(jax.random.split(init_rng, jax.device_count()),
                           learning_rate, momentum)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)
  _, test_loss, test_accuracy = jax_utils.unreplicate(
      apply_model(state, test_ds['image'], test_ds['label']))

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))