從 Haiku 遷移到 Flax#

本指南將逐步說明將 Haiku 模型遷移到 Flax 的流程,並重點說明這兩個函式庫的差異。

基本範例#

要在 Haiku 和 Flax 中建立自訂模組,你會在兩個函式庫建立 模組 基礎類別,創造子類別。不過,Haiku 類別使用一般 __init__ 方法,而 Flax 類別則是 資料類別,表示你要定義一些類別屬性,這些屬性會用於自動產生建構函式。此外,所有 Flax 模組都會接受 名稱 參數,而不需要自行定義;但是,在 Haiku 中,必須在建構函式簽署中明確定義 名稱,並傳遞至父類別建構函式。

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x

在兩個函式庫中,__call__ 方法看來很相似;不過,在 Flax 中,你必須使用 @nn.compact 裝飾器,才能內嵌定義子模組。在 Haiku 中,這是預設的行為。

現在,Haiku 和 Flax 的一大不同之處,在於你組建模型的方式。在 Haiku 中,你會對呼叫你模組的函式使用 hk.transformtransform 會回傳一個物件,具有 initapply 方法。在 Flax 中,你只需執行你模組的實體化作業。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
...


model = Model(256, 10)

為了在這兩個函式庫取得模型參數,你能透過一組 random.key 和一些用於執行模型的輸入,來使用具有 init 方法。這裡的主要差異在於,Flax 會回傳一個從收集名稱到巢狀陣列字典的對應,參數 只是這些可能收集項目的其中一個。在 Haiku 中,你能直接取得 參數 結構。

sample_x = jax.numpy.ones((1, 784))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables["params"]

有一點非常重要,需要注意的是,在 Flax 中,參數結構是階層式的,每個巢狀模組為一層,最後一層則為參數名稱。在 Haiku 中,參數結構是一個具有兩個階層層級的 Python 字典:完全限定的模組名稱對應至參數名稱。模組名稱包含所有巢狀模組的 / 分隔字串路徑。

...
{
  'model/block/linear': {
    'b': (256,),
    'w': (784, 256),
  },
  'model/linear': {
    'b': (10,),
    'w': (256, 10),
  }
}
...
FrozenDict({
  Block_0: {
    Dense_0: {
      bias: (256,),
      kernel: (784, 256),
    },
  },
  Dense_0: {
    bias: (10,),
    kernel: (256, 10),
  },
})

在這兩個框架中訓練期間,你會將參數結構傳遞給 apply 方法,以執行前進傳遞。由於我們使用輟學,在兩種情況下,都必須提供 金鑰apply,以產生隨機輟學遮罩。

def train_step(key, params, inputs, labels):
  def loss_fn(params):
      logits = model.apply(
        params,
        key,
        inputs, training=True # <== inputs
      )
      return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
def train_step(key, params, inputs, labels):
  def loss_fn(params):
      logits = model.apply(
        {'params': params},
        inputs, training=True, # <== inputs
        rngs={'dropout': key}
      )
      return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params

最明顯的不同點是 Flax 中必須在含有 params 鍵的詞典中傳遞參數,並且在含有 dropout 鍵的詞典中傳遞鍵。這是因為在 Flax 中可以有多種類型的模型狀態及隨機狀態。在 Haiku 中,只需直接傳遞參數和鍵。

狀態處理#

現在來看一下這兩個函式庫如何處理可變狀態。我們將採用與先前相同的模型,但現在會使用批次正規化(BatchNorm)取代 Dropout。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x
class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.BatchNorm(
      momentum=0.99
    )(x, use_running_average=not training)
    x = jax.nn.relu(x)
    return x

因為兩者都提供批次正規化層,所以這段程式碼非常類似。最明顯的不同點是 Haiku 使用 is_training 控制是否更新執行中的統計資料,而 Flax 使用 use_running_average 達到相同目的。

要在 Haiku 中實例化狀態模型,請使用 hk.transform_with_state,其會變更 initapply 的簽章以接受並傳回狀態。與先前相同,在 Flax 中可以直接建構模組。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform_with_state(forward)
...


model = Model(256, 10)

要初始化參數和狀態,只需像先前一樣呼叫 init 方法。但是,您現在會在 Haiku 中取得 state 作為第二個傳回值,而在 Flax 中,您會在 variables 詞典中取得新的 batch_stats 集合。請注意,由於 hk.BatchNorm 僅在 is_training=True 時才初始化批次統計資料,因此我們必須在初始化含有 hk.BatchNorm 層的 Haiku 模型參數時,將 training=True 設為真。在 Flax 中,我們可以照常將 training=False 設定為假。

sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
  random.key(0),
  sample_x, training=True # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]

一般來說,在 Flax,您可能會在 variables 字典中找到其他狀態集合,例如 cache 用於自迴歸轉換器模型、 intermediates 用於使用 Module.sow 新增的中間值,或由自訂層定義的其他集合名稱。Haiku 僅區分 params(在執行 apply 時不會變化的變數)和 state(在執行 apply 時會變化的變數)。

現在,這兩個架構的訓練看來非常相似,因為您使用相同的 apply 方法來執行前向傳遞。在 Haiku 中,現在傳遞 state 作為 apply 的第二個引數,並得到新的狀態作為第二個傳回值。在 Flax 中,您會將 batch_stats 新增為輸入字典的新金鑰,並將 updates 變數字典作為第二個傳回值。

def train_step(params, state, inputs, labels):
  def loss_fn(params):
    logits, new_state = model.apply(
      params, state,
      None, # <== rng
      inputs, training=True # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, new_state

  grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, new_state
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params):
    logits, updates = model.apply(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True, # <== inputs
      mutable='batch_stats',
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates["batch_stats"]

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, batch_stats

一個主要的不同點在於,在 Flax,一個狀態集合可以是可變的或不可變的。在 init 期間,所有集合預設都是可變的,然而,在 apply 期間,您必須明確指定哪些集合是可變的。在此範例中,我們指定 batch_stats 是可變的。如果可變集合不止一個,這裡傳遞一個字串,但也可以給予一個清單。如果沒有執行此操作,當嘗試變異 batch_stats 時,執行階段會產生錯誤。此外,當 mutable 除非是 Falseupdates 字典會作為 apply 的第二個傳回值傳回,否則僅傳回模型輸出。Haiku 藉由擁有 params(不可變)和 state(可變)來區分可變與不可變,並使用 hk.transformhk.transform_with_state

使用多種方法#

在此區段,我們將探討如何在 Haiku 與 Flax 中使用多種方法。舉例來說,我們將實作一個包含三個方法的自動編碼器模型:encodedecode__call__

在 Haiku 中,我們只要將 encodedecode 需要的子模組直接定義在 __init__ 中即可,在此情況下,每個子模組只要使用一個 Linear 圖層。在 Flax 中,我們會在 setup 中預先定義一個 encoderdecoder 模組,並分別在 encodedecode 中使用它們。

class AutoEncoder(hk.Module):


  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x
class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

請注意,在 Flax 中,setup 並不會在 __init__ 之後執行,而是會在呼叫 initapply 時執行。

現在,我們希望可以從 AutoEncoder 模型呼叫任何方法。在 Haiku 中,我們可以透過 hk.multi_transform 為一個模組定義多個 apply 方法。傳遞給 multi_transform 的函式定義如何初始化模組,以及要產生哪些不同的 apply 方法。

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
...




model = AutoEncoder(256, 784)

為了初始化模型中的參數,可以使用 init 來觸發 __call__ 方法,這個方法使用 encodedecode 方法。這將為模型建立所有必要的參數。

params = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
params = variables["params"]

這會產生以下參數結構。

{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
FrozenDict({
    decoder: {
        bias: (784,),
        kernel: (256, 784),
    },
    encoder: {
        bias: (256,),
        kernel: (784, 256),
    },
})

最後,讓我們來探討如何使用 apply 函式來呼叫 encode 方法。

encode, decode = model.apply
z = encode(
  params,
  None, # <== rng
  x=jax.numpy.ones((1, 784)),

)
...
z = model.apply(
  {"params": params},

  x=jax.numpy.ones((1, 784)),
  method="encode",
)

由於 Haiku 的 apply 函數是透過 hk.multi_transform 產生的,它會是兩個函數的元組,我們可以將其解封為 encodedecode 函數,這些函數對應到 AutoEncoder 模組上的方法。在 Flax 中,我們會將方法名稱傳入字串中,來呼叫 encode 方法。在這裡,值得注意的另一個區別是,在 Haiku 中,rng 需要明確傳入,即使該模組在 apply 期間不會使用任何隨機操作。在 Flax 中則無此必要(詳見Flax 中的隨機性和 PRNG)。此處 Haiku rng 已設為 None,但你也可以在 apply 函數上使用 hk.without_apply_rng 來移除 rng 參數。

提升變換#

Flax 和 Haiku 都提供一組變換,我們會將它們稱為提升變換,它們會以 JAX 變換的方式進行包裝,以便於搭配モジュール使用,並有時提供其他功能。在本節中,我們將探討如何在 Flax 和 Haiku 中使用提升版本的 scan 來實作一個簡單的 RNN 層。

首先,我們會定義一個 RNNCell 模組,其中會包含 RNN 的單個步驟邏輯。我們也會定義一個 initial_state 方法來初始化 RNN 的狀態(又稱 carry)。與 jax.lax.scan 類似,RNNCell.__call__ 方法會是一個函數,其接受輸入和 carry,並傳回新的 carry 和輸出。在本例中,carry 和輸出是相同的。

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下來,我們將定義一個 RNN 模組,其中將包含整個 RNN 的邏輯。在 Haiku 中,我們將首先初始化 RNNCell,然後使用它來建構 carry,最後使用 hk.scan 在輸入序列中執行 RNNCell。在 Flax 中,它的執行方式略有不同,我們將使用 nn.scan 來定義一個新的暫時類型,包裝 RNNCell。在此過程中,我們還會指定指示 nn.scan 廣播 params 彙集 (所有步驟共享相同的參數),並且不切割 params rng 串流 (因此所有步驟均使用相同的參數初始化),最後,我們將指定我們希望掃描在輸入的第二個軸線上執行,並同時沿著第二個軸線堆疊輸出。然後,我們將立即使用此暫時類型,建立一個提升的 RNNCell 實體,並使用它來建立 carry,並執行 __call__ 方法,這將透過序列進行 scan

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
                  in_axes=1, out_axes=1)(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)
    return y

一般而言,Flax 和 Haiku 之間提升轉換的主要差異在於,在 Haiku 中,提升轉換不會在狀態上執行,也就是說,Haiku 將處理 paramsstate,使其在轉換內外保持相同的形狀。在 Flax 中,提升轉換可以同時在變數彙集和 rng 串流上執行,使用者必須根據轉換的語意,定義每個轉換處理不同彙集的方式。

最後,讓我們快速檢視 RNN 模組如何在 Haiku 和 Flax 中使用。

def forward(x):
  return RNN(64)(x)

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)

y = model.apply(
  params,
  x=jax.numpy.ones((3, 12, 32)),
)
...


model = RNN(64)

variables = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
  {'params': params},
  x=jax.numpy.ones((3, 12, 32)),
)

與前幾節的範例相比,唯一顯著的變更在於,這次我們在 Haiku 中使用 hk.without_apply_rng,因此我們不必將 rng 參數傳遞為 Noneapply 方法。

掃描各層#

函數 scan 中一個非常重要的應用是反覆套用一系列層建立輸入,將每個層的輸出傳遞給下一個層的輸入。此項功能在減少大型模型編譯時間方面非常實用。舉例來說,我們將建立一個簡易的 Block 模組,並在 MLP 模組中使用,這個模型將套用 Block 模組 num_layers 次。

在 Haiku 中,我們依慣例定義 Block 模組,然後在 MLP 中將使用 hk.experimental.layer_stack 搭配 stack_block 函數建立 Block 模組堆疊。在 Flax 中,Block 的定義略有不同,__call__ 將會接受並回傳第二個虛擬輸入/輸出,在兩種情況下都是 None。在 MLP 中,我們將使用 nn.scan,如同前一範例,但透過設定 split_rngs={'params': True}variable_axes={'params': 0} 我們告訴 nn.scan 為每個步驟建立不同的參數,並沿第一個軸向切片 params 蒐集,有效執行如 Haiku 中 Block 模組的堆疊。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers

  def __call__(self, x, training: bool):
    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack_block(x)
class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

注意在 Flax 中我們如何將 None 傳遞給 ScanBlock 作為第二個參數,並忽略其第二個輸出。它們代表每個步驟的輸入/輸出,但它們是 None,因為在這個情況中我們沒有輸入/輸出。

初始化每個模型與前一範例相同。在這個情況中,我們將指定使用 5 個層,每個層有 64 項特徵。

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)

model = hk.transform(forward)

sample_x = jax.numpy.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
...
...


model = MLP(64, num_layers=5)

sample_x = jax.numpy.ones((1, 64))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables['params']

使用圖層掃描時,您應該注意的一件事是,所有圖層都融合為單一圖層,其參數在第一個軸線上具有額外的「圖層」維度。在這種情況下,所有參數的形狀都將以 (5, ...) 開頭,因為我們使用了 5 個圖層。

...
{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}
...
FrozenDict({
    ScanBlock_0: {
        Dense_0: {
            bias: (5, 64),
            kernel: (5, 64, 64),
        },
    },
})

頂層 Haiku 函數與頂層 Flax 模組#

在 Haiku 中,可以使用原始 hk.{get,set}_{parameter,state} 定義/存取模型參數和狀態,將整個模型寫成單一函數。很常將頂層「模組」寫成函數。

Flax 團隊建議採用更以模組為中心的辦法,使用 __call__ 定義前向函數。對應的存取器將是 nn.module.paramnn.module.variable (前往 處理狀態 以查看有關集合的說明)。

def forward(x):


  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
  output = x + multiplier * counter
  hk.set_state("counter", counter + 1)

  return output

model = hk.transform_with_state(forward)

params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
    multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
    output = x + multiplier * counter.value
    if not self.is_initializing():  # otherwise model.init() also increases it
      counter.value += 1
    return output

model = FooModule()
variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']