
Ensembling on multiple devices

We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble is equal to the number of available devices. In a nutshell, the changes are:

  • parallelizing the relevant functions with jax.pmap(),

  • splitting the random seed to obtain different parameter initializations,

  • replicating the inputs and unreplicating the outputs where necessary,

  • averaging the probabilities across devices to compute predictions.

In this brief tutorial we omit some of the code, such as the imports, the CNN module, and the metrics computation, but it can be found in the MNIST example.
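
Since the ensemble size equals the number of available devices, a quick way to see how many ensemble members will be trained is to query JAX directly (a minimal sketch, not part of the example code):

import jax

print(jax.device_count())        # total number of devices = ensemble size
print(jax.local_device_count())  # devices attached to this host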

Parallel functions

We start by creating a parallel version of create_train_state(), which retrieves the initial parameters of the model. We do this using jax.pmap(). The effect of pmapping a function is that it compiles the function with XLA (similar to jax.jit()), but executes it in parallel on XLA devices (e.g., GPUs/TPUs).

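# Single-model version: compiled with jax.jit, initializes one set of parameters.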
@jax.jit
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
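
# Ensemble version: pmapped so that each device initializes its own ensemble member.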
@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

Note that for the single-model code above, we use jax.jit() to lazily initialize the model (see the Module.init() documentation for more details). For the ensembling case, jax.pmap() will by default map over the first axis of the provided argument rng, so we should make sure to provide a different value for each device when we call this function later on.

Also note how we specify that learning_rate and momentum are static arguments, which means the concrete values of these arguments are used rather than abstract shapes. This is necessary because the provided arguments will be scalar values. For more details, see JIT mechanics: tracing and static variables.
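
To illustrate what static_broadcasted_argnums does, here is a minimal toy sketch (not part of the MNIST example): the first argument is sharded over its leading device axis, while the static scalar is passed by value to every device.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, static_broadcasted_argnums=(1,))
def scale(x, factor):
  # `factor` is treated as a compile-time constant on every device.
  return x * factor

n = jax.device_count()
xs = jnp.arange(2.0 * n).reshape((n, 2))  # leading axis: one shard per device
print(scale(xs, 3.0))                     # shape (n, 2)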

Next we do the same for the apply_model() and update_model() functions. To compute predictions from the ensemble, we take the average of the individual probabilities. We use jax.lax.pmean() to compute that average across devices. This also requires us to specify axis_name for both jax.pmap() and jax.lax.pmean().
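
As a minimal toy sketch (again not part of the example), this is how jax.lax.pmean() behaves inside a pmapped function: every device contributes its value, and each device receives the mean over the named axis.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='ensemble')
def device_mean(x):
  # Average x over all devices participating in the 'ensemble' axis.
  return jax.lax.pmean(x, axis_name='ensemble')

xs = jnp.arange(float(jax.device_count()))  # one value per device
print(device_mean(xs))                      # the same mean, replicated on every device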

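# Single-model versions of apply_model() and update_model(), compiled with jax.jit.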
@jax.jit
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)

  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy

@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
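
# Ensemble versions: pmapped with axis_name='ensemble' so jax.lax.pmean can average across devices.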
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads, loss, accuracy

@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

Training the ensemble

Next we transform the train_epoch() function. When calling the pmapped functions from above, we mainly need to take care of replicating the arguments to all devices, and unreplicating the returned values.
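
As a minimal sketch (the shapes are illustrative), flax.jax_utils.replicate() copies a pytree to every device by adding a leading device axis, and jax_utils.unreplicate() retrieves a single copy:

import jax.numpy as jnp
from flax import jax_utils

batch = {'image': jnp.zeros((32, 28, 28, 1)), 'label': jnp.zeros((32,), jnp.int32)}
replicated = jax_utils.replicate(batch)                  # adds a leading device axis (one copy per local device)
print(replicated['image'].shape)                         # e.g. (8, 32, 28, 28, 1) on 8 devices
print(jax_utils.unreplicate(replicated)['image'].shape)  # (32, 28, 28, 1)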

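# Single-model version of train_epoch().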
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
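
# Ensemble version: replicate each batch across devices and unreplicate the returned metrics.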
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = jax_utils.replicate(train_ds['image'][perm, ...])
    batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(jax_utils.unreplicate(loss))
    epoch_accuracy.append(jax_utils.unreplicate(accuracy))
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy

As you can see, we don't have to make any changes to the logic around state. This is because, as we will see below in the training code, the train state is already replicated, so everything just works when we pass it to the pmapped apply_model() and update_model() functions. The training dataset, however, is not replicated yet, so we do that here. Since replicating the entire training dataset would take up too much memory, we do it at the batch level.

We can now rewrite the actual training logic. This consists of two simple changes: splitting the RNG across devices before passing it to create_train_state() (so every ensemble member is initialized differently), and replicating the test dataset (which is much smaller than the training dataset, so we can simply do this for the entire dataset).

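# Single-model training loop.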
train_ds, test_ds = get_datasets()

rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate, momentum)


for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)

  _, test_loss, test_accuracy = apply_model(
      state, test_ds['image'], test_ds['label'])

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))
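
# Ensemble version: split the RNG across devices for initialization and replicate the test set.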
train_ds, test_ds = get_datasets()
test_ds = jax_utils.replicate(test_ds)
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(jax.random.split(init_rng, jax.device_count()),
                           learning_rate, momentum)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)

  _, test_loss, test_accuracy = jax_utils.unreplicate(
      apply_model(state, test_ds['image'], test_ds['label']))

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))