模型手術#

通常,Flax 模組和最佳化器會追蹤和更新參數。但有時候您可能想進行一些模型手術,並自行調整參數張量。本指南會教您如何操作。

設定#

!pip install --upgrade -q pip jax jaxlib flax
import functools

import jax
import jax.numpy as jnp
from flax import traverse_util
from flax import linen as nn
from flax.core import freeze
import jax
import optax

用 Flax Modules 進行手術#

讓我們為我們的範例建立一個小型卷積神經網路模型。

一如往常,您可以執行 CNN.init(...)['params'] 來取得 params,並在訓練的每個步驟中傳遞及修改。

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
      x = nn.Conv(features=32, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = nn.Conv(features=64, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = x.reshape((x.shape[0], -1))
      x = nn.Dense(features=256)(x)
      x = nn.relu(x)
      x = nn.Dense(features=10)(x)
      x = nn.log_softmax(x)
      return x

def get_initial_params(key):
    init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)
    initial_params = CNN().init(key, init_shape)['params']
    return initial_params

key = jax.random.key(0)
params = get_initial_params(key)

jax.tree_util.tree_map(jnp.shape, params)

請注意,作為 params 返回的內容是 FrozenDict,其中包含幾個 JAX 陣列,作為核心與偏差。

FrozenDict 不過是個唯讀字典,Flax 讓它成為唯讀是因為 JAX 的函數特性:JAX 陣列不可變,而且新的 params 需要取代舊的 params。讓字典成為唯讀可以確保字典在訓練與更新期間不會意外地進行局部突變。

在 Flax 模組外實際修改參數的一種方法是,明確地將它扁平化並建立一個可變字典。請注意,您可以使用分隔號 sep 來加入所有巢狀鍵。如果沒有指定 sep,則鍵將會是所有巢狀鍵的元組。

# Get a flattened key-value list.
flat_params = traverse_util.flatten_dict(params, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_params)

現在您可以對參數執行任何您想要的操作。完成後,將它扁平化回原狀並在未來的訓練中使用。

# Somehow modify a layer
dense_kernel = flat_params['Dense_1/kernel']
flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel)

# Unflatten.
unflat_params = traverse_util.unflatten_dict(flat_params, sep='/')
# Refreeze.
unflat_params = freeze(unflat_params)
jax.tree_util.tree_map(jnp.shape, unflat_params)

動手術與最佳化器#

當使用 Optax 作為最佳化器時,opt_state 實際上是在組成最佳化器的個別梯度轉換狀態中的巢狀元組。這些狀態包含鏡像參數樹的 pytree,且可以以相同的方式修改:扁平化、修改、取消扁平化,然後重新建立一個新的最佳化器狀態,鏡像原始狀態。

tx = optax.adam(1.0)
opt_state = tx.init(params)

# The optimizer state is a tuple of gradient transformation states.
jax.tree_util.tree_map(jnp.shape, opt_state)

最佳化器狀態內的 pytree 遵循與參數相同的結構,且可以以完全相同的方式扁平化/修改。

flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/')
flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_mu)

在修改後,重新建立最佳化器狀態。在未來的訓練中使用這個狀態。

opt_state = (
    opt_state[0]._replace(
        mu=traverse_util.unflatten_dict(flat_mu, sep='/'),
        nu=traverse_util.unflatten_dict(flat_nu, sep='/'),
    ),
) + opt_state[1:]
jax.tree_util.tree_map(jnp.shape, opt_state)