Open in Colab Open On GitHub

Flax 基礎#

本筆記書將引導你執行下列工作流程

  • 從 Flax 內建圖層或第三方模型實例化模型。

  • 初始化模型參數並手動執行訓練。

  • 使用 Flax 提供的最佳化器簡化訓練。

  • 序列處理參數和其他物件。

  • 建立自己的模型和管理狀態。

設定我們的環境#

以下是設定筆記書環境所需的程式碼。

# Install the latest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
WARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.dev.org.tw/warnings/venv
WARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.dev.org.tw/warnings/venv
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

使用 Flax 進行線性迴歸#

在先前的「適合急著想學 JAX 的人」筆記書中,我們完成了線性迴歸範例。眾所周知,線性迴歸也可寫成單一稠密神經網路層,我們將在以下內容中說明,以便比較這兩種做法

稠密層是一個層,它具有一個核參數 \(W\in\mathcal{M}_{m,n}(\mathbb{R})\),其中 \(m\) 是模型輸出的特徵數目,而 \(n\) 是輸入的維度,以及一個偏差參數 \(b\in\mathbb{R}^m\)。稠密層從輸入 \(x\in\mathbb{R}^n\) 輸出 \(Wx+b\)

`flax.linen>` 模組中已經提供了這個稠密層(此處匯入為 `nn`)。

# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

層(以及我們從現在開始使用的模型)是 `linen.Module>` 類別的子類別。

模型參數和初始化#

參數不會儲存在模型本身。你需要使用 `init>` 函數、PRNGKey 和虛擬輸入資料來初始化參數。

key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # Dummy input data
params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

注意:與 NumPy 一樣,JAX 和 Flax 是基於列的系統,這表示向量表示為列向量,而不是列向量。這可以在此處的核形狀中看到。

結果符合我們的預期:正確大小的偏差和核參數。在檯面下

  • 虛擬輸入資料 `x` 用於觸發形狀推論:我們只宣告了我們在模型輸出中想要的特徵數,而不是輸入的大小。Flax 會自行找出核的正確大小。

  • 隨機 PRNG 鍵用於觸發初始化函數(模組在此處提供其預設值)。

  • 呼叫初始化函數以產生模型將使用的初始參數集。這些函數包含參數 (PRNG Key, shape, dtype),並傳回形狀為 shape 的陣列。

  • init 函數會傳回已初始化的參數集(您也可以使用 init_with_output 方法,透過採用與 init 相同的語法,取得虛擬輸入的正向傳遞輸出。

為具有給定參數集(絕不會與模型一起儲存)的模型執行正向傳遞,我們只需使用 apply 方法,並提供要使用的參數以及輸入即可

model.apply(params, x)
DeviceArray([-0.7358944,  1.3583755, -0.7976872,  0.8168598,  0.6297793],            dtype=float32)

梯度下降法#

如果您直接跳轉至此而未瀏覽 JAX 部分,以下是我們將使用的線性迴歸公式:來自一組資料點 \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\),我們嘗試尋找一組參數 \(W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m\),使得函數 \(f_{W,b}(x)=Wx+b\) 最小化均方誤差

\[\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2\]

在此處,我們看到元組 \((W,b)\) 與 Dense 層的參數相符。我們將使用這些參數執行梯度下降法。我們先產生我們將使用的假資料。資料與 JAX 部份的線性迴歸 pytree 範例中完全相同。

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)

我們複製與使用 jax.value_and_grad() 在 JAX pytree 線性迴歸範例中相同的訓練迴圈,但在這裡,我們可以使用 model.apply(),而不必定義我們自己的前饋函數(predict_pytree()JAX 範例 中)。

# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

最後執行梯度下降法。

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)
Loss for "true" W,b:  0.023639778
Loss step 0:  38.094772
Loss step 10:  0.44692168
Loss step 20:  0.10053458
Loss step 30:  0.035822745
Loss step 40:  0.018846875
Loss step 50:  0.013864839
Loss step 60:  0.012312559
Loss step 70:  0.011812928
Loss step 80:  0.011649306
Loss step 90:  0.011595251
Loss step 100:  0.0115773035

使用 Optax 最佳化#

Flax 以前使用自己的 flax.optim 套件進行最佳化,但在 FLIP #1009 中,已棄用此套件,並改採用 Optax

Optax 的基本用法很簡單

  1. 選擇最佳化方法(例如 optax.adam)。

  2. 從參數建立最佳化器狀態(對於 Adam 最佳化器,此狀態將包含 動量值)。

  3. 使用 jax.value_and_grad() 計算損失的梯度。

  4. 在每次反覆運算中,呼叫 Optax update 函數以更新內部最佳化器狀態並建立參數的更新。然後使用 Optax 的 apply_updates 方法將更新新增到參數。

請注意,Optax 能執行更多工作:它用於將單純的梯度轉換組成更複雜的轉換,能實作廣泛的最佳化器。也支援隨著時間變更最佳化器超參數(「排程」)、將不同的更新套用到參數樹的不同部分(「遮罩」)等等。如需詳細資料,請參閱 官方文件

import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)
Loss step 0:  0.011576377
Loss step 10:  0.0115710115
Loss step 20:  0.011569244
Loss step 30:  0.011568661
Loss step 40:  0.011568454
Loss step 50:  0.011568379
Loss step 60:  0.011568358
Loss step 70:  0.01156836
Loss step 80:  0.01156835
Loss step 90:  0.011568353
Loss step 100:  0.011568348

序列化結果#

在我們滿意訓練結果後,我們可能會想要儲存模型參數,以便稍後再載入。Flax 提供一個序列化套件讓您能夠執行此工作。

from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
Dict output
{'params': {'bias': DeviceArray([-1.4540135, -2.0262308,  2.0806582,  1.2201802, -0.9964547],            dtype=float32), 'kernel': DeviceArray([[ 1.0106664 ,  0.19014716,  0.04533899, -0.92722285,
               0.34720102],
             [ 1.7320251 ,  0.9901233 ,  1.1662225 ,  1.1027892 ,
              -0.10574618],
             [-1.2009128 ,  0.28837162,  1.4176372 ,  0.12073109,
              -1.3132601 ],
             [-1.1944956 , -0.18993308,  0.03379077,  1.3165942 ,
               0.07996067],
             [ 0.14103189,  1.3737966 , -1.3162128 ,  0.53401774,
              -2.239638  ],
             [ 0.5643044 ,  0.813604  ,  0.31888172,  0.5359193 ,
               0.90352124],
             [-0.37948322,  1.7408353 ,  1.0788013 , -0.5041964 ,
               0.9286919 ],
             [ 0.9701384 , -1.3158673 ,  0.33630812,  0.80941117,
              -1.202457  ],
             [ 1.0198247 , -0.6198277 ,  1.0822718 , -1.8385581 ,
              -0.45790705],
             [-0.64384323,  0.4564892 , -1.1331053 , -0.68556863,
               0.17010891]], dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14\x1d\x1d\xba\xbf\xc4\xad\x01\xc0\x81)\x05@\xdd.\x9c?\xa8\x17\x7f\xbf\xa6kernel\xc7\xd6\x01\x93\x92\n\x05\xa7float32\xc4\xc8\x84]\x81?\xf0\xb5B>`\xb59=z^m\xbfU\xc4\xb1>\x00\xb3\xdd?\xb8x}?\xc7F\x95?2(\x8d?t\x91\xd8\xbd\x83\xb7\x99\xbfr\xa5\x93>#u\xb5?\xdcA\xf7=\xe8\x18\xa8\xbf;\xe5\x98\xbf\xd1}B\xbe0h\n=)\x86\xa8?k\xc2\xa3=\xaaj\x10>\x91\xd8\xaf?\xa9y\xa8\xbfc\xb5\x08?;V\x0f\xc0Av\x10?ZHP?wD\xa3>\x022\t?+Mg?\xa0K\xc2\xbe\xb1\xd3\xde?)\x16\x8a?\x04\x13\x01\xbf\xc1\xbem?\xfdZx?Wn\xa8\xbf\x940\xac>\x925O?\x1c\xea\x99\xbf\x9e\x89\x82?\x07\xad\x1e\xbf\xe2\x87\x8a?\xdfU\xeb\xbf\xcbr\xea\xbe\xe9\xd2$\xbf\xf4\xb8\xe9>\x98\t\x91\xbfm\x81/\xbf\x081.>'

如需載入模型,您需要使用模型參數結構的範本,例如您會從模型初始化取得的範本。在此,我們使用先前產生的 params 作為範本。請注意,這會產生一個新的變數結構,而不是就地進行突變。

透過範本來強制套用結構的目的,可避免使用者遇到下游問題,因此您需要先具備能產生參數結構的正確模型。

serialization.from_bytes(params, bytes_output)
FrozenDict({
    params: {
        bias: array([-1.4540135, -2.0262308,  2.0806582,  1.2201802, -0.9964547],
              dtype=float32),
        kernel: array([[ 1.0106664 ,  0.19014716,  0.04533899, -0.92722285,  0.34720102],
               [ 1.7320251 ,  0.9901233 ,  1.1662225 ,  1.1027892 , -0.10574618],
               [-1.2009128 ,  0.28837162,  1.4176372 ,  0.12073109, -1.3132601 ],
               [-1.1944956 , -0.18993308,  0.03379077,  1.3165942 ,  0.07996067],
               [ 0.14103189,  1.3737966 , -1.3162128 ,  0.53401774, -2.239638  ],
               [ 0.5643044 ,  0.813604  ,  0.31888172,  0.5359193 ,  0.90352124],
               [-0.37948322,  1.7408353 ,  1.0788013 , -0.5041964 ,  0.9286919 ],
               [ 0.9701384 , -1.3158673 ,  0.33630812,  0.80941117, -1.202457  ],
               [ 1.0198247 , -0.6198277 ,  1.0822718 , -1.8385581 , -0.45790705],
               [-0.64384323,  0.4564892 , -1.1331053 , -0.68556863,  0.17010891]],
              dtype=float32),
    },
})

定義您自己的模型#

Flax 讓您能定義自己的模型,這應該會比線性迴歸複雜一些。在此區段,我們將說明如何建立單純的模型。為此,您需要建立基礎 nn.Module 類別的子類別。

請記住,我們匯入了 linen as nn ,這只與新的 linen API 搭配使用

模組基礎#

模型的基本抽象化是 nn.Module 類別,而 Flax 中的每種類型預定義層(例如之前的 Dense)都是 nn.Module 的子類別。讓我們來看看並開始定義一個單純但自訂的多層感知器,也就是穿插呼叫非線性啟用函數的 Dense 層序列。

class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292815e-02 -4.3807115e-02  2.9323792e-02  6.5492536e-03
  -1.7147182e-02]
 [ 1.2967806e-01 -1.4551792e-01  9.4432183e-02  1.2521387e-02
  -4.5417298e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024032e-04  2.7864395e-05  2.4478821e-04  8.1344310e-04
  -1.0110770e-03]]

正如我們所見,nn.Module 子類別由以下幾項組成:

  • 資料欄位集合(nn.Module 是 Python 資料類別) - 此處我們只有型別為 Sequence[int]features欄位。

  • 一個 setup() 方法,在 __postinit__ 結尾處呼叫,您可以在其中註冊模型中需要的子模組、變數、參數。

  • 一個 __call__ 函式,傳回模型根據特定輸入輸出的結果。

  • 模型結構定義參數的 pytree,採用與模型相同的樹狀結構:params tree 包含每層一個 layers_n 子詞典,每個子詞典包含其對應稠密層的參數。此配置是明確且直接的。

注意:清單的管理方式大部分符合預期(WIP),您應注意以下指出的一些特殊情況 在此處

由於模組結構及其參數彼此無關,您無法直接針對特定輸入呼叫 model(x),因為會傳回錯誤。 __call__ 函式會封裝於 apply 函式,才能針對輸入呼叫

try:
    y = model(x) # Returns an error
except AttributeError as e:
    print(e)
"ExplicitMLP" object has no attribute "layers"

由於此處採用的模型非常簡單,可以採用另一種(但等效)方式,使用 @nn.compact 標註在 __call__ 內嵌地宣告子模組,如下所示

class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292815e-02 -4.3807115e-02  2.9323792e-02  6.5492536e-03
  -1.7147182e-02]
 [ 1.2967806e-01 -1.4551792e-01  9.4432183e-02  1.2521387e-02
  -4.5417298e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024032e-04  2.7864395e-05  2.4478821e-04  8.1344310e-04
  -1.0110770e-03]]

不過這兩種宣告模式之間存在一些差異,您應注意這些差異

  • setup 中,您可以命名一些子層並保留這些子層以供後續使用(例如自編碼器中的編碼器/解碼器方法)。

  • 如果您希望有下列多種方法,則必須使用 setup 宣告模組,因為 @nn.compact 標註只允許對一種方法進行標註。

  • 最後初始化方式將有所不同。請參閱以下說明以取得更多詳細資料(待辦:新增說明連結)。

模組參數#

在先前的 MLP 範例中,我們僅依賴於預先定義好的層級運算子(Denserelu)。假設 Flax 沒有提供 Dense 層級,而你想要自行建立,以下示範將透過 @nn.compact 方式來宣告新模組

class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = jnp.dot(inputs, kernel)
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)
initialized parameters:
 FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.6503669 ,  0.86789787,  0.4604268 ],
                     [ 0.05673932,  0.9909285 , -0.63536596],
                     [ 0.76134115, -0.3250529 , -0.65221626],
                     [-0.82430327,  0.4150194 ,  0.19405058]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})
output:
 [[ 0.5035518   1.8548558  -0.4270195 ]
 [ 0.0279097   0.5589246  -0.43061772]
 [ 0.3547128   1.5740999  -0.32865518]
 [ 0.5264864   1.2928858   0.10089308]]

在此範例中,我們瞭解如何宣告並透過 self.param 方法指定模型的參數。該方法會輸入 (name, init_fn, *init_args, **init_kwargs)

  • name 僅為最終顯示於參數結構中的參數名稱。

  • init_fn 為含輸入 (PRNGKey, *init_args, **init_kwargs) 的函式,會傳回陣列,其中 init_argsinit_kwargs 為呼叫初始化函式所需參數。

  • init_argsinit_kwargs 為提供給初始化函式的參數。

也可以在 setup 方法宣告這類參數;由於 Flax 會在第一次呼叫函式時進行延遲初始化,因此無法使用外部形狀推論。

變數和變數集合#

從目前為止的說明得知,使用模型表示使用以下內容:

  • nn.Module 的子類別;

  • 模型的參數 pytree (通常來自 model.init());

不過這還不足以涵蓋機器學習,尤其是神經網路所需的各種功能。在某些情況下,您可能希望神經網路在執行期間保持追蹤一些內部狀態(例如批次正規化層級)。除了模型參數外,透過 variable 方法,您有辦法宣告變數。

為了示範,我們將實作一個批次正規化的簡化類似機制:我們在訓練時會儲存平均數值,並將其從輸入中減去。若要進行實際的批次正規化,應使用(並查看)此處的實作。

class BiasAdderWithRunningMean(nn.Module):
  decay: float = 0.99

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    is_initialized = self.has_variable('batch_stats', 'mean')
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    return x - ra_mean.value + bias


key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
initialized variables:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
    },
    params: {
        bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})

在這裡,updated_state 只回傳模型套用於資料時會變動的狀態變數。要更新變數和取得模型的新參數,可以使用下列模式

for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  old_state, params = flax.core.pop(variables, 'params')
  variables = flax.core.freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part
updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),
    },
})

從這個簡化的範例,您可以衍生完整的 BatchNorm 實作,或包含狀態的任何層。最後,我們來新增一個最佳化器,看看如何同時使用最佳化器更新的參數和狀態變數。

這個範例不會做任何事,僅供示範用途。

from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):

  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()
    return l, updated_state

  (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return opt_state, params, state

x = jnp.ones((10,5))
variables = model.init(random.key(0), x)
state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
  opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
  print('Updated state: ', state)
Updated state:  FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),
    },
})

請注意,上述函數具有相當詳細的簽章,且實際上不能搭配 jax.jit() 使用,因為函數引數不是「有效的 JAX 類型」。

Flax 提供一個便利的包裝程式 - TrainState - 可簡化上述程式碼。請查看 flax.training.train_state.TrainState 以進一步了解。

使用 jax2tf 匯出到 Tensorflow 的 SavedModel#

JAX 發佈了一個名為 jax2tf 的實驗轉換器,可將訓練過的 Flax 模型轉換成 Tensorflow 的 SavedModel 格式 (因此可供 TF HubTF.liteTF.js 或其他下游應用程式使用)。此資源庫包含更多文件,且有各種 Flax 範例。