循環#

用於 Flax 的 RNN 模組。

class flax.nnx.nn.recurrent.LSTMCell(*args, **kwargs)[原始碼]#

LSTM 單元。

單元的數學定義如下

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是輸入,h 是前一個時間步的輸出,c 是記憶。

__call__(carry, inputs)[原始碼]#

一個長短期記憶 (LSTM) 單元。

參數
  • carry – LSTM 單元的隱藏狀態,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。

回傳

一個包含新 carry 和輸出的 tuple。

initialize_carry(input_shape, rngs=None)[原始碼]#

初始化 RNN 單元的 carry。

參數
  • rng – 傳遞至 init_fn 的隨機數產生器。

  • input_shape – 一個 tuple,提供單元輸入的形狀。

回傳

給定 RNN 單元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 單元的 carry。

class flax.nnx.nn.recurrent.OptimizedLSTMCell(*args, **kwargs)[原始碼]#

更有效率的 LSTM 單元,可在矩陣乘法之前串聯狀態元件。

這些參數與 LSTMCell 相容。請注意,只要隱藏大小大致 <= 2048 個單元,此單元通常會比 LSTMCell 快。

單元的數學定義與 LSTMCell 相同,如下所示

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是輸入,h 是前一個時間步的輸出,c 是記憶。

gate_fn#

用於閘門的激活函數 (預設:sigmoid)。

activation_fn#

用於輸出和記憶更新的激活函數 (預設:tanh)。

kernel_init#

用於轉換輸入的 kernel 的初始化器函數 (預設:lecun_normal)。

recurrent_kernel_init#

用於轉換隱藏狀態的 kernel 的初始化器函數 (預設:initializers.orthogonal())。

bias_init#

用於偏置參數的初始化器 (預設:initializers.zeros_init())。

dtype#

計算的 dtype (預設:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的 dtype (預設:float32)。

__call__(carry, inputs)[原始碼]#

一個最佳化的長短期記憶 (LSTM) 單元。

參數
  • carry – LSTM 單元的隱藏狀態,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。

回傳

一個包含新 carry 和輸出的 tuple。

initialize_carry(input_shape, rngs=None)[原始碼]#

初始化 RNN 單元的 carry。

參數
  • rngs – 傳遞至 init_fn 的隨機數產生器。

  • input_shape – 一個 tuple,提供單元輸入的形狀。

回傳

給定 RNN 單元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 單元的 carry。

class flax.nnx.nn.recurrent.SimpleCell(*args, **kwargs)[原始碼]#

簡單單元。

單元的數學定義如下

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]

其中 x 是輸入,h 是前一個時間步的輸出。

如果 residualTrue

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]
__call__(carry, inputs)[原始碼]#

執行 RNN 單元。

參數
  • carry – RNN 單元的隱藏狀態。

  • inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。

回傳

一個包含新 carry 和輸出的 tuple。

initialize_carry(input_shape, rngs=None)[原始碼]#

初始化 RNN 單元的 carry。

參數
  • rng – 傳遞至 init_fn 的隨機數產生器。

  • input_shape – 一個 tuple,提供單元輸入的形狀。

回傳

給定 RNN 單元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 單元的 carry。

class flax.nnx.nn.recurrent.GRUCell(*args, **kwargs)[原始碼]#

GRU 單元。

單元的數學定義如下

\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]

其中 x 是輸入,h 是前一個時間步的輸出。

in_features#

輸入特徵的數量。

hidden_features#

輸出特徵的數量。

gate_fn#

用於閘門的激活函數 (預設:sigmoid)。

activation_fn#

用於輸出和記憶更新的激活函數 (預設:tanh)。

kernel_init#

用於轉換輸入的 kernel 的初始化器函數 (預設:lecun_normal)。

recurrent_kernel_init#

用於轉換隱藏狀態的 kernel 的初始化器函數 (預設:initializers.orthogonal())。

bias_init#

用於偏置參數的初始化器 (預設:initializers.zeros_init())。

dtype#

計算的資料類型(預設:None)。

param_dtype#

傳遞給參數初始化器的 dtype (預設:float32)。

__call__(carry, inputs)[原始碼]#

閘道循環單元 (GRU) 單元。

參數
  • carry – GRU 單元的隱藏狀態,使用 GRUCell.initialize_carry 初始化。

  • inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。

回傳

一個包含新 carry 和輸出的 tuple。

initialize_carry(input_shape, rngs=None)[原始碼]#

初始化 RNN 單元的 carry。

參數
  • rngs – 傳遞至 init_fn 的隨機數產生器。

  • input_shape – 一個 tuple,提供單元輸入的形狀。

回傳

給定 RNN 單元的已初始化 carry。

方法

initialize_carry(input_shape[, rngs])

初始化 RNN 單元的 carry。

class flax.nnx.nn.recurrent.RNN(*args, **kwargs)[原始碼]#

RNN 模組採用任何 RNNCellBase 實例,並使用 flax.nnx.scan() 將其應用於序列上。

使用 flax.nnx.scan()

__call__(inputs, *, initial_carry=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None, rngs=None)[原始碼]#

將自身呼叫為函式。

方法

class flax.nnx.nn.recurrent.Bidirectional(*args, **kwargs)[原始碼]#

在兩個方向上處理輸入並合併結果。

使用範例

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp

>>> # Define forward and backward RNNs
>>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))

>>> # Create Bidirectional layer
>>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)

>>> # Input data
>>> x = jnp.ones((2, 3, 3))

>>> # Apply the layer
>>> out = layer(x)
>>> print(out.shape)
(2, 3, 8)
__call__(inputs, *, initial_carry=None, rngs=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[原始碼]#

將自身呼叫為函式。

方法

flax.nnx.nn.recurrent.flip_sequences(inputs, seq_lengths, num_batch_dims, time_major)[原始碼]#

沿時間軸翻轉輸入序列。

此函式可用於為雙向 LSTM 的反向準備輸入。它解決了當天真地翻轉儲存在矩陣中的多個填充序列時,第一個元素將是那些被填充的序列的填充值問題。此函式會將填充保留在末尾,同時翻轉其餘的元素。

範例

>>> from flax.nnx.nn.recurrent import flip_sequences
>>> from jax import numpy as jnp
>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
>>> lengths = jnp.array([1, 2, 3])
>>> flip_sequences(inputs, lengths, 1, False)
Array([[1, 0, 0],
       [3, 2, 0],
       [6, 5, 4]], dtype=int32)
參數
  • inputs – 輸入 ID 的陣列 <int>[batch_size, seq_length]。

  • lengths – 每個序列的長度 <int>[batch_size]。

回傳

具有翻轉輸入的 ndarray。