將我的程式庫升級到 Linen#

自 Flax v0.4.0 起,flax.nn 不再存在,新的 Linen API 取而代之,位置為 flax.linen。如果你的程式碼庫仍在使用舊 API,你可以使用此升級指南將其升級到 Linen。

定義簡單的 Flax 模組#

from flax import nn

class Dense(base.Module):
  def apply(self,
            inputs,
            features,
            use_bias=True,
            kernel_init=default_kernel_init,
            bias_init=initializers.zeros_init()):

    kernel = self.param('kernel',
      (inputs.shape[-1], features), kernel_init)
    y = jnp.dot(inputs, kernel)
    if use_bias:
      bias = self.param(
        'bias', (features,), bias_init)
      y = y + bias
    return y
from flax import linen as nn  # [1]

class Dense(nn.Module):
  features: int  # [2]
  use_bias: bool = True
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):  # [3]
    kernel = self.param('kernel',
      self.kernel_init, (inputs.shape[-1], self.features))  # [4]
    y = jnp.dot(inputs, kernel)
    if self.use_bias:
      bias = self.param(
        'bias', self.bias_init, (self.features,))  # [5]
      y = y + bias
    return y

  1. from flax import nn 替換為 from flax import linen as nn

  2. 將參數移至 `apply` 的資料類別屬性中。加入型別註解(或使用 `Any` 型別來略過)。

  3. 將方法 `apply` 重新命名為 `__call__`,然後使用 `@compact` 包覆(可選)。`@compact` 包覆的方法可以在方法中直接定義子模組(就像舊版 Flax)。你只能使用 `@compact` 來包覆單一方法。或者,你可以定義 `setup` 方法。有關更多詳細資訊,請參閱我們的另一個 HOWTO 我該使用 setup 還是 nn.compact?

  4. 透過 `self.<attr>` 來存取方法中的資料類別屬性值,例如 `self.features`。

  5. 將形狀移至 self.param 參數列表的尾端(初始化函數可以使用任何參數清單)。

在其他模組內使用 Flax 模組#

class Encoder(nn.Module):

  def apply(self, x):
    x = nn.Dense(x, 500)
    x = nn.relu(x)
    z = nn.Dense(x, 500, name="latents")
    return z
class Encoder(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(500)(x)  # [1]
    x = nn.relu(x)
    z = nn.Dense(500, name='latents')(x)  # [2]
    return z

  1. 模組建構函式不再回傳輸出。取而代之的是,模組建構函式會像一般的建構函式一樣運作,回傳模組實例。這些實例可以像一般 Python 一樣分享(取代使用舊版 Flax 的 `.shared()`)。由於大多數模組都實作 `__call__`,你仍然可以使用舊版 Flax 簡潔的語法。

  2. 名稱可以選擇性地傳遞給所有模組建構函式。

分享子模組和定義多個方法#

class AutoEncoder(nn.Module):
  def _create_submodules(self):
    return Decoder.shared(name="encoder")

  def apply(self, x, z_rng, latents=20):
    decoder = self._create_decoder()
    z = Encoder(x, latents, name="encoder")
    return decoder(z)

  @nn.module_method
  def generate(self, z, **unused_kwargs):
    decoder = self._create_decoder()
    return nn.sigmoid(decoder(z))
class AutoEncoder(nn.Module):
  latents: int = 20

  def setup(self):  # [1]
    self.encoder = Encoder(self.latents)  # [2]
    self.decoder = Decoder()

  def __call__(self, x):  # [3]
    z = self.encoder(x)
    return self.decoder(z)

  def generate(self, z):  # [4]
    return nn.sigmoid(self.decoder(z))

  1. 使用 setup 取代 __init__ 函式,後者已由資料類別庫定義。Flax 會在模組可以使用後準備叫用 setup。(你可以對所有模組執行這項操作,而不用使用 @compact,但我們喜歡 @compact 對模組定義與使用的所在位置進行共同定位,特別是在有迴圈或條件式時)。

  2. 如同一般的 Python,透過在初始化期間指派給 self 來共用次模組。類似於 PyTorch,self.encoder 會自動具備名稱 "encoder"

  3. 我們沒有在此使用 @compact,因為我們未定義任何內嵌次模組(所有次模組都在 setup 中定義)。

  4. 只定義其他方法,如同一般的 Python。

Module.partial 在其他模組內#

# no import

class ResNet(nn.Module):
  """ResNetV1."""


  def apply(self, x,
            stage_sizes,
            num_filters=64,
            train=True):
    conv = nn.Conv.partial(bias=False)
    norm = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=0.9, epsilon=1e-5)

    x = conv(x, num_filters, (7, 7), (2, 2),
            padding=[(3, 3), (3, 3)],
            name='conv_init')
    x = norm(x, name='bn_init')

    # [...]
    return x
from functools import partial

class ResNet(nn.Module):
  """ResNetV1."""
  stage_sizes: Sequence[int]
  num_filters: int = 64
  train: bool = True

  @nn.compact
  def __call__(self, x):
    conv = partial(nn.Conv, use_bias=False)
    norm = partial(nn.BatchNorm,
                  use_running_average=not self.train,
                  momentum=0.9, epsilon=1e-5)

    x = conv(self.num_filters, (7, 7), (2, 2),
            padding=[(3, 3), (3, 3)],
            name='conv_init')(x)
    x = norm(name='bn_init')(x)

    # [...]
    return x

使用一般的 functools.partial 取代 Module.partial。其他部分保持不變。

頂級訓練程式碼樣式#

def create_model(key):
  _, initial_params = CNN.init_by_shape(
    key, [((1, 28, 28, 1), jnp.float32)])
  model = nn.Model(CNN, initial_params)
  return model

def create_optimizer(model, learning_rate):
  optimizer_def = optim.Momentum(learning_rate=learning_rate)
  optimizer = optimizer_def.create(model)
  return optimizer

def cross_entropy_loss(*, logits, labels):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
  return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

def loss_fn(model):
  logits = model(batch['image'])
  one_hot = jax.nn.one_hot(batch['label'], num_classes=10)
  loss = -jnp.mean(jnp.sum(one_hot_labels * batch['label'],
                           axis=-1))
  return loss, logits
def create_train_state(rng, config):  # [1]
  variables = CNN().init(rng, jnp.ones([1, 28, 28, 1]))  # [2]
  params = variables['params']  # [3]
  tx = optax.sgd(config.learning_rate, config.momentum)  # [4]
  return train_state.TrainState.create(
      apply_fn=CNN.apply, params=params, tx=tx)


def loss_fn(params):
  logits = CNN().apply({'params': params}, batch['image'])  # [5]
  one_hot = jax.nn.one_hot(batch['label'], 10)
  loss = jnp.mean(optax.softmax_cross_entropy(logits=logits,
                                              labels=one_hot))
  return loss, logits

  1. 我們不再使用 Model 抽象 – 我們直接傳遞參數,通常會封裝在 TrainState 物件中,這個物件可以直接傳遞至 JAX 轉換。

  2. 要計算初始參數,請建構模組實例並叫用 initinit_with_output。我們尚未移植 init_by_shape,因為這項函式執行一些我們不喜歡的魔法(它根據形狀評估函式,但無論如何都會傳回真實值)。因此,您現在應該將具體值傳遞給初始化器函式,您可以透過將初始化器函式包裝在 jax.jit 中來最佳化初始化,強烈建議使用此方法來避免執行完整的正向傳遞。

  3. Linen 將參數概括為變數。參數是由變數所組成的「集合」之一。變數是巢狀的 dict,其中頂層鍵反映了不同的變數集合,其中「參數」是其中之一。請參閱 變數文件 以取得更多詳細資訊。

  4. 我們建議使用 Optax 最佳化器。更多詳細資訊,請參閱我們稱為 將我的程式庫升級至 Optax 的單獨 HOWTO。

  5. 若要對模型進行預測,請在頂層建立一個執行個體(這是免費的,只是建構函式屬性的包裝器),並呼叫 apply 方法(它內部會呼叫 __call__)。

不可訓練變數(「狀態」):在模組內使用#

class BatchNorm(nn.Module):
  def apply(self, x):
    # [...]
    ra_mean = self.state(
      'mean', (x.shape[-1], ), initializers.zeros_init())
    ra_var = self.state(
      'var', (x.shape[-1], ), initializers.ones_init())
    # [...]
class BatchNorm(nn.Module):
  def __call__(self, x):
    # [...]
    ra_mean = self.variable(
      'batch_stats', 'mean', initializers.zeros_init(), (x.shape[-1], ))
    ra_var = self.variable(
      'batch_stats', 'var', initializers.ones_init(), (x.shape[-1], ))
    # [...]

第一個引數是變數集合的名稱(「param」是唯一永遠可用的變數集合)。某些集合在頂層訓練程式碼中可能被視為可變動的,而另一些則視為不可變的(詳細資訊請參閱下一部分)。Flax 也讓您在模組內使用 JAX 轉換時,可以對每個變數集合執行不同的處理。

不可訓練變數(「狀態」):頂層訓練程式碼模式#

# initial params and state
def initial_model(key, init_batch):
  with nn.stateful() as initial_state:
    _, initial_params = ResNet.init(key, init_batch)
  model = nn.Model(ResNet, initial_params)
  return model, init_state


# updates batch statistics during training
def loss_fn(model, model_state):
  with nn.stateful(model_state) as new_model_state:
    logits = model(batch['image'])
  # [...]



# reads immutable batch statistics during evaluation
def eval_step(model, model_state, batch):
  with nn.stateful(model_state, mutable=False):
    logits = model(batch['image'], train=False)
  return compute_metrics(logits, batch['label'])
# initial variables ({"param": ..., "batch_stats": ...})
def initial_variables(key, init_batch):
  return ResNet().init(key, init_batch)  # [1]



# updates batch statistics during training
def loss_fn(params, batch_stats):
  variables = {'params': params, 'batch_stats': batch_stats}  # [2]
  logits, new_variables = ResNet(train=true).apply(
    variables, batch['image'], mutable=['batch_stats'])  # [3]
  new_batch_stats = new_variables['batch_stats']
  # [...]


# reads immutable batch statistics during evaluation
def eval_step(params, batch_stats, batch):
  variables = {'params': params, 'batch_stats': batch_stats}
  logits = ResNet(train=False).apply(
    variables, batch['image'], mutable=False)  # [4]
  return compute_metrics(logits, batch['label'])

  1. init 會回傳一個變數字典,例如 {"param": ..., "batch_stats": ...}(請參閱 變數文件)。

  2. 將不同變數集合合併到變數字典中。

  3. 在訓練期間,batch_stats 變數集合會變更。由於我們在可變動引數中指定這一點,module.apply 的傳回值會變成一個由 output, new_variables 組成的有序對。

  4. 在評估期間,如果我們意外地在訓練模式中套用批次正規化,我們想提出錯誤。透過將 mutable=False 傳遞到 module.apply 中,我們強制執行這一點。由於沒有變數被變動,傳回值再次僅為輸出。

載入預 Linen 檢查點#

雖然大多數 Linen 模組都應該能夠使用預 Linen 權重而不進行任何修改,但有一個陷阱:在預 Linen API 中,子模組會遞增編號,與子模組類別無關。Flax 已更改此行為,讓每個模組類別保持獨立的子模組計數。

在預 Linen 中,param 具有以下結構

{'Conv_0': { ... }, 'Dense_1': { ... } }

但在 Linen 中,結構會變成這樣

{'Conv_0': { ... }, 'Dense_0': { ... } }

待辦事項:在此新增示例說明如何載入新的 TrainState 物件。

隨機性#

def dropout(inputs, rate, deterministic=False):
  keep_prob = 1. - rate
  if deterministic:
    return inputs
  else:
    mask = random.bernoulli(
    make_rng(), p=keep_prob, shape=inputs.shape)
    return lax.select(
      mask, inputs / keep_prob, jnp.zeros_like(inputs))


def loss_fn(model, dropout_rng):
  with nn.stochastic(dropout_rng):
    logits = model(inputs)
class Dropout(nn.Module):
  rate: float

  @nn.compact
  def __call__(self, inputs, deterministic=False):
    keep_prob = 1. - self.rate
    if deterministic:
      return inputs
    else:
      mask = random.bernoulli(
        self.make_rng('dropout'), p=keep_prob, shape=inputs.shape)  # [1]
      return lax.select(
        mask, inputs / keep_prob, jnp.zeros_like(inputs))


def loss_fn(params, dropout_rng):
  logits = Transformer().apply(
    {'params': params}, inputs, rngs={'dropout': dropout_rng})  # [2]

  1. Linen 中的 RNG 有「類型」-此範例中的類型為 'dropout'。不同的類型在 JAX 轉換中可以有不同的處理方式(例如,您希望序列模型中的每個時間步驟都使用相同的 dropout 遮罩,還是不同的遮罩?)

  2. 與其使用 nn.stochastic 背景管理員,您可以將 RNG 明確傳遞至 module.apply。在評估期間您不會傳遞任何 RNG-如此一來,如果您意外在非確定性模式下使用 dropout,self.make_rng('dropout') 會產生錯誤。

提升的轉換#

在 Linen 中,與其直接使用 JAX 轉換,我們使用「提升的轉換」,亦即套用至 Flax 模組的 JAX 轉換。

若要深入瞭解,請參閱有關 提升的轉換 的設計說明。

待辦事項:提供 jax.scan_in_dim(Linen 之前版本)與 nn.scan(Linen)的範例。