循環#
用於 Flax 的 RNN 模組。
- class flax.nnx.nn.recurrent.LSTMCell(*args, **kwargs)[原始碼]#
LSTM 單元。
單元的數學定義如下
\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]其中 x 是輸入,h 是前一個時間步的輸出,c 是記憶。
- __call__(carry, inputs)[原始碼]#
一個長短期記憶 (LSTM) 單元。
- 參數
carry – LSTM 單元的隱藏狀態,使用
LSTMCell.initialize_carry
初始化。inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。
- 回傳
一個包含新 carry 和輸出的 tuple。
- initialize_carry(input_shape, rngs=None)[原始碼]#
初始化 RNN 單元的 carry。
- 參數
rng – 傳遞至 init_fn 的隨機數產生器。
input_shape – 一個 tuple,提供單元輸入的形狀。
- 回傳
給定 RNN 單元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 單元的 carry。
- class flax.nnx.nn.recurrent.OptimizedLSTMCell(*args, **kwargs)[原始碼]#
更有效率的 LSTM 單元,可在矩陣乘法之前串聯狀態元件。
這些參數與
LSTMCell
相容。請注意,只要隱藏大小大致 <= 2048 個單元,此單元通常會比LSTMCell
快。單元的數學定義與
LSTMCell
相同,如下所示\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]其中 x 是輸入,h 是前一個時間步的輸出,c 是記憶。
- gate_fn#
用於閘門的激活函數 (預設:sigmoid)。
- activation_fn#
用於輸出和記憶更新的激活函數 (預設:tanh)。
- kernel_init#
用於轉換輸入的 kernel 的初始化器函數 (預設:lecun_normal)。
- recurrent_kernel_init#
用於轉換隱藏狀態的 kernel 的初始化器函數 (預設:initializers.orthogonal())。
- bias_init#
用於偏置參數的初始化器 (預設:initializers.zeros_init())。
- dtype#
計算的 dtype (預設:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的 dtype (預設:float32)。
- __call__(carry, inputs)[原始碼]#
一個最佳化的長短期記憶 (LSTM) 單元。
- 參數
carry – LSTM 單元的隱藏狀態,使用
LSTMCell.initialize_carry
初始化。inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。
- 回傳
一個包含新 carry 和輸出的 tuple。
- initialize_carry(input_shape, rngs=None)[原始碼]#
初始化 RNN 單元的 carry。
- 參數
rngs – 傳遞至 init_fn 的隨機數產生器。
input_shape – 一個 tuple,提供單元輸入的形狀。
- 回傳
給定 RNN 單元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 單元的 carry。
- class flax.nnx.nn.recurrent.SimpleCell(*args, **kwargs)[原始碼]#
簡單單元。
單元的數學定義如下
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]其中 x 是輸入,h 是前一個時間步的輸出。
如果 residual 為 True,
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]- __call__(carry, inputs)[原始碼]#
執行 RNN 單元。
- 參數
carry – RNN 單元的隱藏狀態。
inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。
- 回傳
一個包含新 carry 和輸出的 tuple。
- initialize_carry(input_shape, rngs=None)[原始碼]#
初始化 RNN 單元的 carry。
- 參數
rng – 傳遞至 init_fn 的隨機數產生器。
input_shape – 一個 tuple,提供單元輸入的形狀。
- 回傳
給定 RNN 單元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 單元的 carry。
- class flax.nnx.nn.recurrent.GRUCell(*args, **kwargs)[原始碼]#
GRU 單元。
單元的數學定義如下
\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]其中 x 是輸入,h 是前一個時間步的輸出。
- in_features#
輸入特徵的數量。
輸出特徵的數量。
- gate_fn#
用於閘門的激活函數 (預設:sigmoid)。
- activation_fn#
用於輸出和記憶更新的激活函數 (預設:tanh)。
- kernel_init#
用於轉換輸入的 kernel 的初始化器函數 (預設:lecun_normal)。
- recurrent_kernel_init#
用於轉換隱藏狀態的 kernel 的初始化器函數 (預設:initializers.orthogonal())。
- bias_init#
用於偏置參數的初始化器 (預設:initializers.zeros_init())。
- dtype#
計算的資料類型(預設:None)。
- param_dtype#
傳遞給參數初始化器的 dtype (預設:float32)。
- __call__(carry, inputs)[原始碼]#
閘道循環單元 (GRU) 單元。
- 參數
carry – GRU 單元的隱藏狀態,使用
GRUCell.initialize_carry
初始化。inputs – 一個 ndarray,包含目前時間步的輸入。除了最後一個維度外,所有維度都被視為批次維度。
- 回傳
一個包含新 carry 和輸出的 tuple。
- initialize_carry(input_shape, rngs=None)[原始碼]#
初始化 RNN 單元的 carry。
- 參數
rngs – 傳遞至 init_fn 的隨機數產生器。
input_shape – 一個 tuple,提供單元輸入的形狀。
- 回傳
給定 RNN 單元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 單元的 carry。
- class flax.nnx.nn.recurrent.RNN(*args, **kwargs)[原始碼]#
RNN
模組採用任何RNNCellBase
實例,並使用flax.nnx.scan()
將其應用於序列上。使用
flax.nnx.scan()
。- __call__(inputs, *, initial_carry=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None, rngs=None)[原始碼]#
將自身呼叫為函式。
方法
- class flax.nnx.nn.recurrent.Bidirectional(*args, **kwargs)[原始碼]#
在兩個方向上處理輸入並合併結果。
使用範例
>>> from flax import nnx >>> import jax >>> import jax.numpy as jnp >>> # Define forward and backward RNNs >>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) >>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) >>> # Create Bidirectional layer >>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn) >>> # Input data >>> x = jnp.ones((2, 3, 3)) >>> # Apply the layer >>> out = layer(x) >>> print(out.shape) (2, 3, 8)
- __call__(inputs, *, initial_carry=None, rngs=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[原始碼]#
將自身呼叫為函式。
方法
- flax.nnx.nn.recurrent.flip_sequences(inputs, seq_lengths, num_batch_dims, time_major)[原始碼]#
沿時間軸翻轉輸入序列。
此函式可用於為雙向 LSTM 的反向準備輸入。它解決了當天真地翻轉儲存在矩陣中的多個填充序列時,第一個元素將是那些被填充的序列的填充值問題。此函式會將填充保留在末尾,同時翻轉其餘的元素。
範例
>>> from flax.nnx.nn.recurrent import flip_sequences >>> from jax import numpy as jnp >>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]) >>> lengths = jnp.array([1, 2, 3]) >>> flip_sequences(inputs, lengths, 1, False) Array([[1, 0, 0], [3, 2, 0], [6, 5, 4]], dtype=int32)
- 參數
inputs – 輸入 ID 的陣列 <int>[batch_size, seq_length]。
lengths – 每個序列的長度 <int>[batch_size]。
- 回傳
具有翻轉輸入的 ndarray。