從 Haiku 遷移至 Flax#
本指南示範 Haiku 和 Flax NNX 模型之間的差異,提供並排的範例程式碼,以協助您從 Haiku 遷移至 Flax NNX API。
如果您是 Flax NNX 的新手,請務必熟悉 Flax NNX 基礎知識,其中涵蓋 nnx.Module
系統、Flax 轉換,以及附帶範例的 函數式 API。
讓我們從一些匯入開始。
基本模組定義#
Haiku 和 Flax 都使用 Module
類別作為表達神經網路庫層的預設單位。例如,若要建立具有 dropout 和 ReLU 激活函數的單層網路,您需要
首先,建立一個
Block
(透過子類別化Module
),其中包含一個具有 dropout 和 ReLU 激活函數的線性層。然後,在建立
Model
(也是透過子類別化Module
)時,使用Block
作為子Module
,其中Model
由Block
和一個線性層組成。
Haiku 和 Flax 的 Module
物件之間有兩個根本差異
無狀態與有狀態:
haiku.Module
實例是無狀態的。這表示變數是從純粹函數式的Module.init()
呼叫返回,並單獨管理。然而,
flax.nnx.Module
擁有其變數作為此 Python 物件的屬性。
惰性與及早:
haiku.Module
僅在使用者呼叫模型時實際看到輸入時(惰性)才配置空間來建立變數。flax.nnx.Module
實例會在實例化時立即建立變數,然後才看到範例輸入(及早)。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
from flax import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
變數創建#
本節說明如何實例化模型並初始化其參數。
若要為 Haiku 模型產生模型參數,您需要將其放在正向函數內,並使用
haiku.transform
使其成為純函數式。這會產生 JAX 陣列(jax.Array
資料類型)的巢狀字典,以便單獨攜帶和維護。在 Flax NNX 中,當您實例化模型時,會自動初始化模型參數,且變數(
nnx.Variable
物件)會儲存在nnx.Module
(或其子模組)內,作為屬性。您仍然需要為其提供一個 虛擬亂數產生器 (PRNG) 金鑰,但該金鑰會包裝在nnx.Rngs
類別中並儲存在內部,並在需要時產生更多 PRNG 金鑰。
如果您想以無狀態、類似字典的方式存取 Flax 模型參數以進行檢查點儲存或模型手術,請查看 Flax NNX 分割/合併 API(nnx.split
/ nnx.merge
)。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
sample_x = jnp.ones((1, 784))
params = model.init(jax.random.key(0), sample_x, training=False)
assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
...
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# Parameters were already initialized during model instantiation.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
訓練步驟和編譯#
本節涵蓋如何使用 JAX 即時編譯來編寫訓練步驟並進行編譯。
當編譯訓練步驟時
Haiku 使用
@jax.jit
(一種 JAX 轉換)來編譯純函數式的訓練步驟。Flax NNX 使用
@nnx.jit
(一種 Flax NNX 轉換)(與 JAX 轉換類似的多種轉換 API 之一,但同時也與 Flax 物件搭配使用效果良好)。雖然jax.jit
僅接受具有純粹無狀態引數的函數,但flax.nnx.jit
允許引數為有狀態的模組。這大幅減少了訓練步驟所需的程式碼行數。
當取得梯度時
類似地,Haiku 使用
jax.grad
(一種用於 自動微分的 JAX 轉換)來傳回原始梯度字典。同時,Flax NNX 使用
flax.nnx.grad
(一種 Flax NNX 轉換)以flax.nnx.State
字典的形式傳回 Flax NNX 模組的梯度。如果您想搭配 Flax NNX 使用正規的jax.grad
,您需要使用分割/合併 API。
對於最佳化器
如果您已經使用 Optax 最佳化器,例如
optax.adamw
(而不是此處顯示的原始jax.tree.map
計算)搭配 Haiku,請查看 Flax 基礎知識指南中的flax.nnx.Optimizer
範例,以了解更簡潔的訓練和更新模型方式。
每次訓練步驟期間的模型更新
Haiku 訓練步驟需要傳回參數的 JAX pytree,作為下一步的輸入。
Flax NNX 訓練步驟不需要傳回任何內容,因為
model
已在nnx.jit
內就地更新。此外,
nnx.Module
物件是有狀態的,且Module
會自動追蹤其中的數個項目,例如 PRNG 金鑰和flax.nnx.BatchNorm
統計資料。這就是為何您不需要在每個步驟中明確傳入 PRNG 金鑰的原因。另請注意,您可以使用flax.nnx.reseed
來重設其底層 PRNG 狀態。
dropout 行為
在 Haiku 中,您需要明確定義並傳入
training
引數來切換haiku.dropout
,並確保僅在training=True
時才會發生隨機 dropout。在 Flax NNX 中,您可以呼叫
model.train()
(flax.nnx.Module.train()
) 來自動將flax.nnx.Dropout
切換至訓練模式。相反地,您可以呼叫model.eval()
(flax.nnx.Module.eval()
) 來關閉訓練模式。您可以參考其 API 參考,以深入瞭解flax.nnx.Module.train
的功能。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params, key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # set deterministic=False
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(
inputs, # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
處理非參數狀態#
Haiku 會區分可訓練的參數和模型追蹤的所有其他資料(「狀態」)。例如,批次標準化中使用的批次統計資訊會被視為狀態。具有狀態的模型需要使用 hk.transform_with_state
進行轉換,以便其 .init()
回傳參數和狀態。
在 Flax 中,沒有如此強烈的區別 - 它們都是 nnx.Variable
的子類別,並且被模組視為其屬性。參數是名為 nnx.Param
的子類別的實例,而批次統計資訊可以是另一個名為 nnx.BatchStat
的子類別。您可以使用 nnx.split
來快速提取特定變數類型的所有資料。
讓我們透過採用上面的 Block
定義,但將 dropout 替換為 BatchNorm
來看看這個範例。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(jax.random.key(0), sample_x, training=True)
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
x = jax.nn.relu(x)
return x
model = Block(4, 4, rngs=nnx.Rngs(0))
model.linear.kernel # Param(value=...)
model.batchnorm.mean # BatchStat(value=...)
Flax 會考慮可訓練參數和其他資料之間的差異。nnx.grad
將只針對 nnx.Param
變數計算梯度,因此會自動跳過 batchnorm
陣列。因此,對於具有此模型的 Flax NNX,訓練步驟看起來會相同。
使用多種方法#
在本節中,您將學習如何在 Haiku 和 Flax 中使用多種方法。舉例來說,您將實作一個具有三種方法的自動編碼器模型:encode
、decode
和 __call__
。
在 Haiku 中,您需要使用 hk.multi_transform
來明確定義如何初始化模型以及它可以呼叫哪些方法(此處為 encode
和 decode
)。請注意,您仍然需要定義一個 __call__
,它會啟動兩個層以延遲初始化所有模型參數。
在 Flax 中,它更簡單,因為您在 __init__
中初始化參數,並且可以直接使用 nnx.Module
方法 encode
和 decode
。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
params = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
...
參數結構如下
...
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
},
'encoder': {
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
要呼叫這些自訂方法
在 Haiku 中,您需要解耦 .apply 函式,以在呼叫它之前提取您的方法。
在 Flax 中,您可以直接呼叫該方法。
encode, decode = model.apply
z = encode(params, None, x=jnp.ones((1, 784)))
...
z = model.encode(jnp.ones((1, 784)))
轉換#
Haiku 和 Flax 轉換都提供它們自己的一組轉換,這些轉換以可以使用 Module
物件的方式包裝 JAX 轉換。
有關 Flax 轉換的更多資訊,請查看轉換指南。
讓我們從一個範例開始
首先,定義一個
RNNCell
Module
,它將包含 RNN 單一步驟的邏輯。定義一個
initial_state
方法,該方法將用於初始化 RNN 的狀態(又名carry
)。與jax.lax.scan
(API 文件) 一樣,RNNCell.__call__
方法將是一個接受 carry 和輸入,並回傳新的 carry 和輸出的函式。在此情況下,carry 和輸出是相同的。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下來,我們將定義一個 RNN
模組,它將包含整個 RNN 的邏輯。在這兩種情況下,我們都使用程式庫的 scan
呼叫來在輸入序列上執行 RNNCell
。
唯一的區別在於,Flax nnx.scan
允許您在引數 in_axes
和 out_axes
中指定要重複的軸,這些軸將轉發到基礎的 `jax.lax.scan<https://jax.dev.org.tw/en/latest/_autosummary/jax.lax.scan.html>`__,而在 Haiku 中,您需要明確地轉換輸入和輸出。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(
cell, carry,
jnp.swapaxes(x, 1, 0)
)
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
掃描圖層#
大多數 Haiku 轉換應該看起來與 Flax 類似,因為它們都包裝了它們的 JAX 對應項,但是掃描圖層的使用案例是例外。
掃描圖層是一種技術,您可以在其中通過 N 個重複圖層的序列來執行輸入,並將每個圖層的輸出作為下一個圖層的輸入傳遞。此模式可以顯著減少大型模型的編譯時間。在下面的範例中,您將在頂層 MLP
Module
中重複 Block
Module
5 次。
在 Haiku 中,我們像往常一樣定義 Block
模組,然後在 MLP
內部,我們將在 stack_block
函式上使用 hk.experimental.layer_stack
來建立 Block
模組的堆疊。相同的程式碼將在初始化時建立 5 層參數,並在呼叫時將輸入通過它們。
在 Flax 中,模型初始化和呼叫程式碼是完全解耦的,因此我們使用 nnx.vmap
轉換來初始化基礎的 Block
參數,並使用 nnx.scan
轉換來在它們中執行模型輸入。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jnp.ones((1, 64))
params = model.init(jax.random.key(0), sample_x, training=False)
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array): # No need to require a second input!
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x # No need to return a second output!
class MLP(nnx.Module):
def __init__(self, features, num_layers, rngs):
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return Block(features, features, rngs=rngs)
self.blocks = create_block(rngs)
self.num_layers = num_layers
def __call__(self, x):
@nnx.split_rngs(splits=self.num_layers)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, model):
x = model(x)
return x
return forward(x, self.blocks)
model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
在上面的 Flax 範例中,還有一些其他細節需要說明
`@nnx.split_rngs` 修飾器: Flax 轉換(如同它們的 JAX 對應項)完全與 PRNG 狀態無關,並且依賴於 PRNG 金鑰的輸入。
nnx.split_rngs
修飾器允許您在將nnx.Rngs
傳遞給修飾函式之前分割它們,然後在之後「降低」它們,以便它們可以在外部使用。在這裡,您分割了 PRNG 金鑰,因為如果每個內部操作都需要自己的金鑰,則
jax.vmap
和jax.lax.scan
需要 PRNG 金鑰的清單。因此,對於MLP
內的 5 層,您將在向下傳遞到 JAX 轉換之前,從其引數分割並提供 5 個不同的 PRNG 金鑰。請注意,實際上,
create_block()
知道它需要建立 5 個圖層恰恰是因為它看到了 5 個 PRNG 金鑰,因為in_axes=(0,)
表示vmap
將查看第一個引數的第一個維度,以了解它將對應的大小。對於
forward()
也是如此,它會查看第一個引數(又名model
)內的變數,以找出它需要掃描多少次。nnx.split_rngs
實際上會在這裡分割model
內部的 PRNG 狀態。(如果Block
Module
沒有 dropout,則您不需要nnx.split_rngs
行,因為它無論如何都不會消耗任何 PRNG 金鑰。)
為什麼 Flax 中的 Block 模組不需要接收和回傳額外的虛擬值:
jax.lax.scan
(API 文件要求其函式回傳兩個輸入 - carry 和堆疊輸出。在這種情況下,我們沒有使用後者。Flax 簡化了此過程,因此如果您設定out_axes=nnx.Carry
而不是預設的(nnx.Carry, 0)
,現在您可以選擇忽略第二個輸出。這是 Flax NNX 轉換與 JAX 轉換 API 不同的少數情況之一。
在上面的 Flax 範例中,程式碼行數較多,但它們更精確地表達了每個時間點發生的事情。由於 Flax 的轉換方式與 JAX 轉換 API 更為接近,建議在使用其 Flax NNX 對應項之前,先充分理解底層的 JAX 轉換。
現在檢查雙方的變數 PyTree。
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
頂層 Haiku 函數 vs 頂層 Flax 模組#
在 Haiku 中,可以透過使用原始的 hk.{get,set}_{parameter,state}
來定義/存取模型參數和狀態,將整個模型寫成單一函數。通常會將頂層的「模組」寫成一個函數。
Flax 團隊建議採用更以模組為中心的方法,使用 __call__
來定義前向函數。在 Flax 模組中,可以使用常規的 Python 類別語義來設定和存取參數和變數。
...
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter(
'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(jax.random.key(0), jnp.ones((1, 64)))
class Counter(nnx.Variable):
pass
class FooModule(nnx.Module):
def __init__(self, rngs):
self.counter = Counter(jnp.ones((), jnp.int32))
self.multiplier = nnx.Param(
nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
)
def __call__(self, x):
output = x + self.multiplier * self.counter.value
self.counter.value += 1
return output
model = FooModule(rngs=nnx.Rngs(0))
_, params, counter = nnx.split(model, nnx.Param, Counter)