RNNCellBase 升級指南

RNNCellBase 升級指南#

為了提升RNNCellBase API 的可用性,該 API 進行了以下幾個關鍵更新

  • initialize_carry 方法已從類別方法轉換為執行個體方法,簡化了其應用方式。

  • 所有必要元資料現在都直接儲存在儲存格執行個體內,提供了簡化的方法簽章。

本指南將帶您逐步了解這些變更,說明如何更新現有程式碼以符合這些增強功能。

基本用法#

先來定義一些變數和代表一批順序的範例輸入

batch_size = 32
seq_len = 10
in_features = 64
out_features = 128

x = jnp.ones((batch_size, seq_len, in_features))

最重要的是,所有元資料(包含特徵數、承載初始化程式等等)現在都儲存在儲存格實例中

cell = nn.LSTMCell()
cell = nn.LSTMCell(features=out_features)

一個重大的變更為,initialize_carry 已轉換為一個實例方法。由於儲存格實例現在包含所有元資料,initialize_carry 方法的簽章只要 PRNG 鍵和一個範例輸入

carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features)
carry = cell.initialize_carry(jax.random.key(0), x[:, 0].shape)

在此,x[:, 0].shape 代表儲存格的輸入(不含時間維度)。在更方便的時候,您也可以直接建立輸入形狀

carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features))

升級模式#

以下各節將說明一些將程式碼更新為符合新 API 的有用模式。

首先,我們將顯示如何升級一個包含儲存格的模組,在__call__期間套用掃描邏輯,並具備靜態initialize_carry方法。在此,我們將盡可能減少對程式碼的變更,以使程式碼能順利運作,即使不是最慣用的方式

class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):

    return nn.OptimizedLSTMCell()(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell.initialize_carry(
      jax.random.key(0), batch_dims, hidden_size)
class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    features = carry[0].shape[-1]
    return nn.OptimizedLSTMCell(features)(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell(hidden_size, parent=None).initialize_carry(
      jax.random.key(0), (*batch_dims, hidden_size))

請注意,在新版中,我們必須從 __call__ 期間的承載中擷取特徵數,並在initialize_carry期間使用parent=None,以避免潛在的副作用。

接下來,我們將顯示一種撰寫類似 LSTM 模組的慣用方式。這裡的主要變更將新增一個特徵屬性在模組中,並使用它在setup方法中初始化一個nn.scan-ed 版本的儲存格

class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    return nn.OptimizedLSTMCell()(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell.initialize_carry(
      jax.random.key(0), batch_dims, hidden_size)

model = SimpleLSTM()
carry = SimpleLSTM.initialize_carry((batch_size,), out_features)
variables = model.init(jax.random.key(0), carry, x)
class SimpleLSTM(nn.Module):
  features: int

  def setup(self):
    self.scan_cell = nn.transforms.scan(
      nn.OptimizedLSTMCell,
      variable_broadcast='params',
      in_axes=1, out_axes=1,
      split_rngs={'params': False})(self.features)


  @nn.compact
  def __call__(self, x):
    carry = self.scan_cell.initialize_carry(jax.random.key(0), x[:, 0].shape)
    return self.scan_cell(carry, x)[1]  # only return the output


model = SimpleLSTM(features=out_features)
variables = model.init(jax.random.key(0), x)

由於承載可以從範例輸入中輕鬆初始化,因此我們可以將initialize_carry的呼叫移動到__call__方法中,這可以簡化程式碼。

開發備註#

開發新儲存格時,請考慮下列事項:

  • 將必要的元資料包含為實例屬性。

  • initialize_carry 現在只需要 PRNG 金鑰和樣本輸入。

  • 新的 num_feature_axes 屬性是必需的,用來指定特徵度量。

class LSTMCell(nn.RNNCellBase):
  features: int # ← All metadata is now stored within the cell instance
  ... #              ↓
  carry_init: Initializer

  def initialize_carry(self, rng, input_shape) -> Carry:
    ...

  @property
  def num_feature_axes(self):
    return 1

num_feature_axes 是新的 API 特性,它允許處理隨意 RNNCellBase 執行個體的程式碼,例如 RNN 模組,以推斷批次度量並決定時間軸的位置。