從 Haiku 遷移到 Flax#
本指南將逐步說明將 Haiku 模型遷移到 Flax 的流程,並重點說明這兩個函式庫的差異。
基本範例#
要在 Haiku 和 Flax 中建立自訂模組,你會在兩個函式庫建立 模組
基礎類別,創造子類別。不過,Haiku 類別使用一般 __init__
方法,而 Flax 類別則是 資料類別
,表示你要定義一些類別屬性,這些屬性會用於自動產生建構函式。此外,所有 Flax 模組都會接受 名稱
參數,而不需要自行定義;但是,在 Haiku 中,必須在建構函式簽署中明確定義 名稱
,並傳遞至父類別建構函式。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
在兩個函式庫中,__call__
方法看來很相似;不過,在 Flax 中,你必須使用 @nn.compact
裝飾器,才能內嵌定義子模組。在 Haiku 中,這是預設的行為。
現在,Haiku 和 Flax 的一大不同之處,在於你組建模型的方式。在 Haiku 中,你會對呼叫你模組的函式使用 hk.transform
,transform
會回傳一個物件,具有 init
和 apply
方法。在 Flax 中,你只需執行你模組的實體化作業。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
...
model = Model(256, 10)
為了在這兩個函式庫取得模型參數,你能透過一組 random.key
和一些用於執行模型的輸入,來使用具有 init
方法。這裡的主要差異在於,Flax 會回傳一個從收集名稱到巢狀陣列字典的對應,參數
只是這些可能收集項目的其中一個。在 Haiku 中,你能直接取得 參數
結構。
sample_x = jax.numpy.ones((1, 784))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables["params"]
有一點非常重要,需要注意的是,在 Flax 中,參數結構是階層式的,每個巢狀模組為一層,最後一層則為參數名稱。在 Haiku 中,參數結構是一個具有兩個階層層級的 Python 字典:完全限定的模組名稱對應至參數名稱。模組名稱包含所有巢狀模組的 /
分隔字串路徑。
...
{
'model/block/linear': {
'b': (256,),
'w': (784, 256),
},
'model/linear': {
'b': (10,),
'w': (256, 10),
}
}
...
FrozenDict({
Block_0: {
Dense_0: {
bias: (256,),
kernel: (784, 256),
},
},
Dense_0: {
bias: (10,),
kernel: (256, 10),
},
})
在這兩個框架中訓練期間,你會將參數結構傳遞給 apply
方法,以執行前進傳遞。由於我們使用輟學,在兩種情況下,都必須提供 金鑰
給 apply
,以產生隨機輟學遮罩。
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params,
key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
最明顯的不同點是 Flax 中必須在含有 params
鍵的詞典中傳遞參數,並且在含有 dropout
鍵的詞典中傳遞鍵。這是因為在 Flax 中可以有多種類型的模型狀態及隨機狀態。在 Haiku 中,只需直接傳遞參數和鍵。
狀態處理#
現在來看一下這兩個函式庫如何處理可變狀態。我們將採用與先前相同的模型,但現在會使用批次正規化(BatchNorm)取代 Dropout。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.BatchNorm(
momentum=0.99
)(x, use_running_average=not training)
x = jax.nn.relu(x)
return x
因為兩者都提供批次正規化層,所以這段程式碼非常類似。最明顯的不同點是 Haiku 使用 is_training
控制是否更新執行中的統計資料,而 Flax 使用 use_running_average
達到相同目的。
要在 Haiku 中實例化狀態模型,請使用 hk.transform_with_state
,其會變更 init
和 apply
的簽章以接受並傳回狀態。與先前相同,在 Flax 中可以直接建構模組。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
...
model = Model(256, 10)
要初始化參數和狀態,只需像先前一樣呼叫 init
方法。但是,您現在會在 Haiku 中取得 state
作為第二個傳回值,而在 Flax 中,您會在 variables
詞典中取得新的 batch_stats
集合。請注意,由於 hk.BatchNorm
僅在 is_training=True
時才初始化批次統計資料,因此我們必須在初始化含有 hk.BatchNorm
層的 Haiku 模型參數時,將 training=True
設為真。在 Flax 中,我們可以照常將 training=False
設定為假。
sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
random.key(0),
sample_x, training=True # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
一般來說,在 Flax,您可能會在 variables
字典中找到其他狀態集合,例如 cache
用於自迴歸轉換器模型、 intermediates
用於使用 Module.sow
新增的中間值,或由自訂層定義的其他集合名稱。Haiku 僅區分 params
(在執行 apply
時不會變化的變數)和 state
(在執行 apply
時會變化的變數)。
現在,這兩個架構的訓練看來非常相似,因為您使用相同的 apply
方法來執行前向傳遞。在 Haiku 中,現在傳遞 state
作為 apply
的第二個引數,並得到新的狀態作為第二個傳回值。在 Flax 中,您會將 batch_stats
新增為輸入字典的新金鑰,並將 updates
變數字典作為第二個傳回值。
def train_step(params, state, inputs, labels):
def loss_fn(params):
logits, new_state = model.apply(
params, state,
None, # <== rng
inputs, training=True # <== inputs
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, new_state
grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, new_state
def train_step(params, batch_stats, inputs, labels):
def loss_fn(params):
logits, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
inputs, training=True, # <== inputs
mutable='batch_stats',
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, updates["batch_stats"]
grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, batch_stats
一個主要的不同點在於,在 Flax,一個狀態集合可以是可變的或不可變的。在 init
期間,所有集合預設都是可變的,然而,在 apply
期間,您必須明確指定哪些集合是可變的。在此範例中,我們指定 batch_stats
是可變的。如果可變集合不止一個,這裡傳遞一個字串,但也可以給予一個清單。如果沒有執行此操作,當嘗試變異 batch_stats
時,執行階段會產生錯誤。此外,當 mutable
除非是 False
, updates
字典會作為 apply
的第二個傳回值傳回,否則僅傳回模型輸出。Haiku 藉由擁有 params
(不可變)和 state
(可變)來區分可變與不可變,並使用 hk.transform
或 hk.transform_with_state
使用多種方法#
在此區段,我們將探討如何在 Haiku 與 Flax 中使用多種方法。舉例來說,我們將實作一個包含三個方法的自動編碼器模型:encode
、decode
及 __call__
。
在 Haiku 中,我們只要將 encode
與 decode
需要的子模組直接定義在 __init__
中即可,在此情況下,每個子模組只要使用一個 Linear
圖層。在 Flax 中,我們會在 setup
中預先定義一個 encoder
與 decoder
模組,並分別在 encode
與 decode
中使用它們。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
請注意,在 Flax 中,setup
並不會在 __init__
之後執行,而是會在呼叫 init
或 apply
時執行。
現在,我們希望可以從 AutoEncoder
模型呼叫任何方法。在 Haiku 中,我們可以透過 hk.multi_transform
為一個模組定義多個 apply
方法。傳遞給 multi_transform
的函式定義如何初始化模組,以及要產生哪些不同的 apply 方法。
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
...
model = AutoEncoder(256, 784)
為了初始化模型中的參數,可以使用 init
來觸發 __call__
方法,這個方法使用 encode
與 decode
方法。這將為模型建立所有必要的參數。
params = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
params = variables["params"]
這會產生以下參數結構。
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
FrozenDict({
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
})
最後,讓我們來探討如何使用 apply
函式來呼叫 encode
方法。
encode, decode = model.apply
z = encode(
params,
None, # <== rng
x=jax.numpy.ones((1, 784)),
)
...
z = model.apply(
{"params": params},
x=jax.numpy.ones((1, 784)),
method="encode",
)
由於 Haiku 的 apply
函數是透過 hk.multi_transform
產生的,它會是兩個函數的元組,我們可以將其解封為 encode
和 decode
函數,這些函數對應到 AutoEncoder
模組上的方法。在 Flax 中,我們會將方法名稱傳入字串中,來呼叫 encode
方法。在這裡,值得注意的另一個區別是,在 Haiku 中,rng
需要明確傳入,即使該模組在 apply
期間不會使用任何隨機操作。在 Flax 中則無此必要(詳見Flax 中的隨機性和 PRNG)。此處 Haiku rng
已設為 None
,但你也可以在 apply
函數上使用 hk.without_apply_rng
來移除 rng
參數。
提升變換#
Flax 和 Haiku 都提供一組變換,我們會將它們稱為提升變換,它們會以 JAX 變換的方式進行包裝,以便於搭配モジュール使用,並有時提供其他功能。在本節中,我們將探討如何在 Flax 和 Haiku 中使用提升版本的 scan
來實作一個簡單的 RNN 層。
首先,我們會定義一個 RNNCell
模組,其中會包含 RNN 的單個步驟邏輯。我們也會定義一個 initial_state
方法來初始化 RNN 的狀態(又稱 carry
)。與 jax.lax.scan
類似,RNNCell.__call__
方法會是一個函數,其接受輸入和 carry,並傳回新的 carry 和輸出。在本例中,carry 和輸出是相同的。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下來,我們將定義一個 RNN
模組,其中將包含整個 RNN 的邏輯。在 Haiku 中,我們將首先初始化 RNNCell
,然後使用它來建構 carry
,最後使用 hk.scan
在輸入序列中執行 RNNCell
。在 Flax 中,它的執行方式略有不同,我們將使用 nn.scan
來定義一個新的暫時類型,包裝 RNNCell
。在此過程中,我們還會指定指示 nn.scan
廣播 params
彙集 (所有步驟共享相同的參數),並且不切割 params
rng 串流 (因此所有步驟均使用相同的參數初始化),最後,我們將指定我們希望掃描在輸入的第二個軸線上執行,並同時沿著第二個軸線堆疊輸出。然後,我們將立即使用此暫時類型,建立一個提升的 RNNCell
實體,並使用它來建立 carry
,並執行 __call__
方法,這將透過序列進行 scan
。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
in_axes=1, out_axes=1)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
一般而言,Flax 和 Haiku 之間提升轉換的主要差異在於,在 Haiku 中,提升轉換不會在狀態上執行,也就是說,Haiku 將處理 params
和 state
,使其在轉換內外保持相同的形狀。在 Flax 中,提升轉換可以同時在變數彙集和 rng 串流上執行,使用者必須根據轉換的語意,定義每個轉換處理不同彙集的方式。
最後,讓我們快速檢視 RNN
模組如何在 Haiku 和 Flax 中使用。
def forward(x):
return RNN(64)(x)
model = hk.without_apply_rng(hk.transform(forward))
params = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
y = model.apply(
params,
x=jax.numpy.ones((3, 12, 32)),
)
...
model = RNN(64)
variables = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
{'params': params},
x=jax.numpy.ones((3, 12, 32)),
)
與前幾節的範例相比,唯一顯著的變更在於,這次我們在 Haiku 中使用 hk.without_apply_rng
,因此我們不必將 rng
參數傳遞為 None
給 apply
方法。
掃描各層#
函數 scan
中一個非常重要的應用是反覆套用一系列層建立輸入,將每個層的輸出傳遞給下一個層的輸入。此項功能在減少大型模型編譯時間方面非常實用。舉例來說,我們將建立一個簡易的 Block
模組,並在 MLP
模組中使用,這個模型將套用 Block
模組 num_layers
次。
在 Haiku 中,我們依慣例定義 Block
模組,然後在 MLP
中將使用 hk.experimental.layer_stack
搭配 stack_block
函數建立 Block
模組堆疊。在 Flax 中,Block
的定義略有不同,__call__
將會接受並回傳第二個虛擬輸入/輸出,在兩種情況下都是 None
。在 MLP
中,我們將使用 nn.scan
,如同前一範例,但透過設定 split_rngs={'params': True}
和 variable_axes={'params': 0}
我們告訴 nn.scan
為每個步驟建立不同的參數,並沿第一個軸向切片 params
蒐集,有效執行如 Haiku 中 Block
模組的堆疊。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
注意在 Flax 中我們如何將 None
傳遞給 ScanBlock
作為第二個參數,並忽略其第二個輸出。它們代表每個步驟的輸入/輸出,但它們是 None
,因為在這個情況中我們沒有輸入/輸出。
初始化每個模型與前一範例相同。在這個情況中,我們將指定使用 5 個層,每個層有 64 項特徵。
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jax.numpy.ones((1, 64))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
...
model = MLP(64, num_layers=5)
sample_x = jax.numpy.ones((1, 64))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables['params']
使用圖層掃描時,您應該注意的一件事是,所有圖層都融合為單一圖層,其參數在第一個軸線上具有額外的「圖層」維度。在這種情況下,所有參數的形狀都將以 (5, ...)
開頭,因為我們使用了 5
個圖層。
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
FrozenDict({
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
})
頂層 Haiku 函數與頂層 Flax 模組#
在 Haiku 中,可以使用原始 hk.{get,set}_{parameter,state}
定義/存取模型參數和狀態,將整個模型寫成單一函數。很常將頂層「模組」寫成函數。
Flax 團隊建議採用更以模組為中心的辦法,使用 __call__ 定義前向函數。對應的存取器將是 nn.module.param 和 nn.module.variable (前往 處理狀態 以查看有關集合的說明)。
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
@nn.compact
def __call__(self, x):
counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
output = x + multiplier * counter.value
if not self.is_initializing(): # otherwise model.init() also increases it
counter.value += 1
return output
model = FooModule()
variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']