從 Haiku 遷移至 Flax#

本指南示範 Haiku 和 Flax NNX 模型之間的差異,提供並排的範例程式碼,以協助您從 Haiku 遷移至 Flax NNX API。

如果您是 Flax NNX 的新手,請務必熟悉 Flax NNX 基礎知識,其中涵蓋 nnx.Module 系統、Flax 轉換,以及附帶範例的 函數式 API

讓我們從一些匯入開始。

基本模組定義#

Haiku 和 Flax 都使用 Module 類別作為表達神經網路庫層的預設單位。例如,若要建立具有 dropout 和 ReLU 激活函數的單層網路,您需要

  • 首先,建立一個 Block(透過子類別化 Module),其中包含一個具有 dropout 和 ReLU 激活函數的線性層。

  • 然後,在建立 Model(也是透過子類別化 Module)時,使用 Block 作為子 Module,其中 ModelBlock 和一個線性層組成。

Haiku 和 Flax 的 Module 物件之間有兩個根本差異

  • 無狀態與有狀態:

    • haiku.Module 實例是無狀態的。這表示變數是從純粹函數式的 Module.init() 呼叫返回,並單獨管理。

    • 然而,flax.nnx.Module 擁有其變數作為此 Python 物件的屬性。

  • 惰性與及早:

    • haiku.Module 僅在使用者呼叫模型時實際看到輸入時(惰性)才配置空間來建立變數。

    • flax.nnx.Module 實例會在實例化時立即建立變數,然後才看到範例輸入(及早)。

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)


  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

變數創建#

本節說明如何實例化模型並初始化其參數。

  • 若要為 Haiku 模型產生模型參數,您需要將其放在正向函數內,並使用 haiku.transform 使其成為純函數式。這會產生 JAX 陣列jax.Array 資料類型)的巢狀字典,以便單獨攜帶和維護。

  • 在 Flax NNX 中,當您實例化模型時,會自動初始化模型參數,且變數(nnx.Variable 物件)會儲存在 nnx.Module(或其子模組)內,作為屬性。您仍然需要為其提供一個 虛擬亂數產生器 (PRNG) 金鑰,但該金鑰會包裝在 nnx.Rngs 類別中並儲存在內部,並在需要時產生更多 PRNG 金鑰。

如果您想以無狀態、類似字典的方式存取 Flax 模型參數以進行檢查點儲存或模型手術,請查看 Flax NNX 分割/合併 APInnx.split / nnx.merge)。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
sample_x = jnp.ones((1, 784))
params = model.init(jax.random.key(0), sample_x, training=False)


assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
...


model = Model(784, 256, 10, rngs=nnx.Rngs(0))


# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

訓練步驟和編譯#

本節涵蓋如何使用 JAX 即時編譯來編寫訓練步驟並進行編譯。

當編譯訓練步驟時

  • Haiku 使用 @jax.jit(一種 JAX 轉換)來編譯純函數式的訓練步驟。

  • Flax NNX 使用 @nnx.jit(一種 Flax NNX 轉換)(與 JAX 轉換類似的多種轉換 API 之一,但同時也與 Flax 物件搭配使用效果良好)。雖然 jax.jit 僅接受具有純粹無狀態引數的函數,但 flax.nnx.jit 允許引數為有狀態的模組。這大幅減少了訓練步驟所需的程式碼行數。

當取得梯度時

  • 類似地,Haiku 使用 jax.grad(一種用於 自動微分的 JAX 轉換)來傳回原始梯度字典。

  • 同時,Flax NNX 使用 flax.nnx.grad(一種 Flax NNX 轉換)以 flax.nnx.State 字典的形式傳回 Flax NNX 模組的梯度。如果您想搭配 Flax NNX 使用正規的 jax.grad,您需要使用分割/合併 API

對於最佳化器

  • 如果您已經使用 Optax 最佳化器,例如 optax.adamw(而不是此處顯示的原始 jax.tree.map 計算)搭配 Haiku,請查看 Flax 基礎知識指南中的 flax.nnx.Optimizer 範例,以了解更簡潔的訓練和更新模型方式。

每次訓練步驟期間的模型更新

  • Haiku 訓練步驟需要傳回參數的 JAX pytree,作為下一步的輸入。

  • Flax NNX 訓練步驟不需要傳回任何內容,因為 model 已在 nnx.jit 內就地更新。

  • 此外,nnx.Module 物件是有狀態的,且 Module 會自動追蹤其中的數個項目,例如 PRNG 金鑰和 flax.nnx.BatchNorm 統計資料。這就是為何您不需要在每個步驟中明確傳入 PRNG 金鑰的原因。另請注意,您可以使用 flax.nnx.reseed 來重設其底層 PRNG 狀態。

dropout 行為

  • 在 Haiku 中,您需要明確定義並傳入 training 引數來切換 haiku.dropout,並確保僅在 training=True 時才會發生隨機 dropout。

  • 在 Flax NNX 中,您可以呼叫 model.train() (flax.nnx.Module.train()) 來自動將 flax.nnx.Dropout 切換至訓練模式。相反地,您可以呼叫 model.eval() (flax.nnx.Module.eval()) 來關閉訓練模式。您可以參考其 API 參考,以深入瞭解 flax.nnx.Module.train 的功能。

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      params, key,
      inputs, training=True # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)


  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
model.train() # set deterministic=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(

      inputs, # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.GraphState.merge(params, rest))

處理非參數狀態#

Haiku 會區分可訓練的參數和模型追蹤的所有其他資料(「狀態」)。例如,批次標準化中使用的批次統計資訊會被視為狀態。具有狀態的模型需要使用 hk.transform_with_state 進行轉換,以便其 .init() 回傳參數和狀態。

在 Flax 中,沒有如此強烈的區別 - 它們都是 nnx.Variable 的子類別,並且被模組視為其屬性。參數是名為 nnx.Param 的子類別的實例,而批次統計資訊可以是另一個名為 nnx.BatchStat 的子類別。您可以使用 nnx.split 來快速提取特定變數類型的所有資料。

讓我們透過採用上面的 Block 定義,但將 dropout 替換為 BatchNorm 來看看這個範例。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features



  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x

def forward(x, training: bool):
  return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)

sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(jax.random.key(0), sample_x, training=True)
class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)


    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)

Flax 會考慮可訓練參數和其他資料之間的差異。nnx.grad 將只針對 nnx.Param 變數計算梯度,因此會自動跳過 batchnorm 陣列。因此,對於具有此模型的 Flax NNX,訓練步驟看起來會相同。

使用多種方法#

在本節中,您將學習如何在 Haiku 和 Flax 中使用多種方法。舉例來說,您將實作一個具有三種方法的自動編碼器模型:encodedecode__call__

在 Haiku 中,您需要使用 hk.multi_transform 來明確定義如何初始化模型以及它可以呼叫哪些方法(此處為 encodedecode)。請注意,您仍然需要定義一個 __call__,它會啟動兩個層以延遲初始化所有模型參數。

在 Flax 中,它更簡單,因為您在 __init__ 中初始化參數,並且可以直接使用 nnx.Module 方法 encodedecode

class AutoEncoder(hk.Module):

  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
params = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):

  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):

    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)











model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
...

參數結構如下

...


{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
})

要呼叫這些自訂方法

  • 在 Haiku 中,您需要解耦 .apply 函式,以在呼叫它之前提取您的方法。

  • 在 Flax 中,您可以直接呼叫該方法。

encode, decode = model.apply
z = encode(params, None, x=jnp.ones((1, 784)))
...
z = model.encode(jnp.ones((1, 784)))

轉換#

Haiku 和 Flax 轉換都提供它們自己的一組轉換,這些轉換以可以使用 Module 物件的方式包裝 JAX 轉換

有關 Flax 轉換的更多資訊,請查看轉換指南

讓我們從一個範例開始

  • 首先,定義一個 RNNCell Module,它將包含 RNN 單一步驟的邏輯。

  • 定義一個 initial_state 方法,該方法將用於初始化 RNN 的狀態(又名 carry)。與 jax.lax.scan (API 文件) 一樣,RNNCell.__call__ 方法將是一個接受 carry 和輸入,並回傳新的 carry 和輸出的函式。在此情況下,carry 和輸出是相同的。

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下來,我們將定義一個 RNN 模組,它將包含整個 RNN 的邏輯。在這兩種情況下,我們都使用程式庫的 scan 呼叫來在輸入序列上執行 RNNCell

唯一的區別在於,Flax nnx.scan 允許您在引數 in_axesout_axes 中指定要重複的軸,這些軸將轉發到基礎的 `jax.lax.scan<https://jax.dev.org.tw/en/latest/_autosummary/jax.lax.scan.html>`__,而在 Haiku 中,您需要明確地轉換輸入和輸出。

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(
      cell, carry,
      jnp.swapaxes(x, 1, 0)
    )
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

掃描圖層#

大多數 Haiku 轉換應該看起來與 Flax 類似,因為它們都包裝了它們的 JAX 對應項,但是掃描圖層的使用案例是例外。

掃描圖層是一種技術,您可以在其中通過 N 個重複圖層的序列來執行輸入,並將每個圖層的輸出作為下一個圖層的輸入傳遞。此模式可以顯著減少大型模型的編譯時間。在下面的範例中,您將在頂層 MLP Module 中重複 Block Module 5 次。

在 Haiku 中,我們像往常一樣定義 Block 模組,然後在 MLP 內部,我們將在 stack_block 函式上使用 hk.experimental.layer_stack 來建立 Block 模組的堆疊。相同的程式碼將在初始化時建立 5 層參數,並在呼叫時將輸入通過它們。

在 Flax 中,模型初始化和呼叫程式碼是完全解耦的,因此我們使用 nnx.vmap 轉換來初始化基礎的 Block 參數,並使用 nnx.scan 轉換來在它們中執行模型輸入。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers





  def __call__(self, x, training: bool):

    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack_block(x)

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)

sample_x = jnp.ones((1, 64))
params = model.init(jax.random.key(0), sample_x, training=False)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)



model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

在上面的 Flax 範例中,還有一些其他細節需要說明

  • `@nnx.split_rngs` 修飾器: Flax 轉換(如同它們的 JAX 對應項)完全與 PRNG 狀態無關,並且依賴於 PRNG 金鑰的輸入。 nnx.split_rngs 修飾器允許您在將 nnx.Rngs 傳遞給修飾函式之前分割它們,然後在之後「降低」它們,以便它們可以在外部使用。

    • 在這裡,您分割了 PRNG 金鑰,因為如果每個內部操作都需要自己的金鑰,則 jax.vmapjax.lax.scan 需要 PRNG 金鑰的清單。因此,對於 MLP 內的 5 層,您將在向下傳遞到 JAX 轉換之前,從其引數分割並提供 5 個不同的 PRNG 金鑰。

    • 請注意,實際上,create_block() 知道它需要建立 5 個圖層恰恰是因為它看到了 5 個 PRNG 金鑰,因為 in_axes=(0,) 表示 vmap 將查看第一個引數的第一個維度,以了解它將對應的大小。

    • 對於 forward() 也是如此,它會查看第一個引數(又名 model)內的變數,以找出它需要掃描多少次。nnx.split_rngs 實際上會在這裡分割 model 內部的 PRNG 狀態。(如果 Block Module 沒有 dropout,則您不需要 nnx.split_rngs 行,因為它無論如何都不會消耗任何 PRNG 金鑰。)

  • 為什麼 Flax 中的 Block 模組不需要接收和回傳額外的虛擬值: jax.lax.scan (API 文件要求其函式回傳兩個輸入 - carry 和堆疊輸出。在這種情況下,我們沒有使用後者。Flax 簡化了此過程,因此如果您設定 out_axes=nnx.Carry 而不是預設的 (nnx.Carry, 0),現在您可以選擇忽略第二個輸出。

    • 這是 Flax NNX 轉換與 JAX 轉換 API 不同的少數情況之一。

在上面的 Flax 範例中,程式碼行數較多,但它們更精確地表達了每個時間點發生的事情。由於 Flax 的轉換方式與 JAX 轉換 API 更為接近,建議在使用其 Flax NNX 對應項之前,先充分理解底層的 JAX 轉換

現在檢查雙方的變數 PyTree。

...


{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}



...
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
})

頂層 Haiku 函數 vs 頂層 Flax 模組#

在 Haiku 中,可以透過使用原始的 hk.{get,set}_{parameter,state} 來定義/存取模型參數和狀態,將整個模型寫成單一函數。通常會將頂層的「模組」寫成一個函數。

Flax 團隊建議採用更以模組為中心的方法,使用 __call__ 來定義前向函數。在 Flax 模組中,可以使用常規的 Python 類別語義來設定和存取參數和變數。

...


def forward(x):


  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter(
    'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
  )

  output = x + multiplier * counter

  hk.set_state("counter", counter + 1)
  return output

model = hk.transform_with_state(forward)

params, state = model.init(jax.random.key(0), jnp.ones((1, 64)))
class Counter(nnx.Variable):
  pass

class FooModule(nnx.Module):

  def __init__(self, rngs):
    self.counter = Counter(jnp.ones((), jnp.int32))
    self.multiplier = nnx.Param(
      nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
    )
  def __call__(self, x):
    output = x + self.multiplier * self.counter.value

    self.counter.value += 1
    return output

model = FooModule(rngs=nnx.Rngs(0))

_, params, counter = nnx.split(model, nnx.Param, Counter)