RNNCellBase 升級指南#
為了提升RNNCellBase
API 的可用性,該 API 進行了以下幾個關鍵更新
initialize_carry
方法已從類別方法轉換為執行個體方法,簡化了其應用方式。所有必要元資料現在都直接儲存在儲存格執行個體內,提供了簡化的方法簽章。
本指南將帶您逐步了解這些變更,說明如何更新現有程式碼以符合這些增強功能。
基本用法#
先來定義一些變數和代表一批順序的範例輸入
batch_size = 32
seq_len = 10
in_features = 64
out_features = 128
x = jnp.ones((batch_size, seq_len, in_features))
最重要的是,所有元資料(包含特徵數、承載初始化程式等等)現在都儲存在儲存格實例中
cell = nn.LSTMCell()
cell = nn.LSTMCell(features=out_features)
一個重大的變更為,initialize_carry
已轉換為一個實例方法。由於儲存格實例現在包含所有元資料,initialize_carry
方法的簽章只要 PRNG 鍵和一個範例輸入
carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features)
carry = cell.initialize_carry(jax.random.key(0), x[:, 0].shape)
在此,x[:, 0].shape
代表儲存格的輸入(不含時間維度)。在更方便的時候,您也可以直接建立輸入形狀
carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features))
升級模式#
以下各節將說明一些將程式碼更新為符合新 API 的有用模式。
首先,我們將顯示如何升級一個包含儲存格的模組
,在__call__
期間套用掃描邏輯,並具備靜態initialize_carry
方法。在此,我們將盡可能減少對程式碼的變更,以使程式碼能順利運作,即使不是最慣用的方式
class SimpleLSTM(nn.Module):
@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
return nn.OptimizedLSTMCell()(carry, x)
@staticmethod
def initialize_carry(batch_dims, hidden_size):
return nn.OptimizedLSTMCell.initialize_carry(
jax.random.key(0), batch_dims, hidden_size)
class SimpleLSTM(nn.Module):
@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
features = carry[0].shape[-1]
return nn.OptimizedLSTMCell(features)(carry, x)
@staticmethod
def initialize_carry(batch_dims, hidden_size):
return nn.OptimizedLSTMCell(hidden_size, parent=None).initialize_carry(
jax.random.key(0), (*batch_dims, hidden_size))
請注意,在新版中,我們必須從 __call__
期間的承載中擷取特徵數,並在initialize_carry
期間使用parent=None
,以避免潛在的副作用。
接下來,我們將顯示一種撰寫類似 LSTM 模組的慣用方式。這裡的主要變更將新增一個特徵
屬性在模組中,並使用它在setup
方法中初始化一個nn.scan
-ed 版本的儲存格
class SimpleLSTM(nn.Module):
@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
return nn.OptimizedLSTMCell()(carry, x)
@staticmethod
def initialize_carry(batch_dims, hidden_size):
return nn.OptimizedLSTMCell.initialize_carry(
jax.random.key(0), batch_dims, hidden_size)
model = SimpleLSTM()
carry = SimpleLSTM.initialize_carry((batch_size,), out_features)
variables = model.init(jax.random.key(0), carry, x)
class SimpleLSTM(nn.Module):
features: int
def setup(self):
self.scan_cell = nn.transforms.scan(
nn.OptimizedLSTMCell,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})(self.features)
@nn.compact
def __call__(self, x):
carry = self.scan_cell.initialize_carry(jax.random.key(0), x[:, 0].shape)
return self.scan_cell(carry, x)[1] # only return the output
model = SimpleLSTM(features=out_features)
variables = model.init(jax.random.key(0), x)
由於承載
可以從範例輸入中輕鬆初始化,因此我們可以將initialize_carry
的呼叫移動到__call__
方法中,這可以簡化程式碼。
開發備註#
開發新儲存格時,請考慮下列事項:
將必要的元資料包含為實例屬性。
initialize_carry
現在只需要 PRNG 金鑰和樣本輸入。新的
num_feature_axes
屬性是必需的,用來指定特徵度量。
class LSTMCell(nn.RNNCellBase):
features: int # ← All metadata is now stored within the cell instance
... # ↓
carry_init: Initializer
def initialize_carry(self, rng, input_shape) -> Carry:
...
@property
def num_feature_axes(self):
return 1
num_feature_axes
是新的 API 特性,它允許處理隨意 RNNCellBase
執行個體的程式碼,例如 RNN
模組,以推斷批次度量並決定時間軸的位置。