為何選擇 Flax NNX?#
在 2020 年,Flax 團隊發布了 Flax Linen API,以支援 JAX 上的模型研究,重點在於擴展性和效能。從那時起,我們從使用者身上學到了很多。團隊引入了一些已被證明對使用者有益的概念,例如
將變數組織成集合。
自動且有效率的偽隨機數生成器 (PRNG) 管理。
變數元數據,用於單程式多數據 (SPMD)註解、優化器元數據和其他使用案例。
Flax 團隊所做的選擇之一是使用函數式 (compact
) 語義,透過參數的延遲初始化進行神經網路編程。這使得實作程式碼簡潔,並使 Flax Linen API 與 Haiku 對齊。
然而,這也意味著 Flax 中模組和變數的語義是非 Python 式的,而且常常令人驚訝。它還導致了實作複雜性,並模糊了對神經網路進行轉換 (transforms)的核心概念。
介紹 Flax NNX#
快轉到 2024 年,Flax 團隊開發了 Flax NNX - 試圖保留使 Flax Linen 對使用者有用的功能,同時引入一些新的原則。Flax NNX 背後的中心思想是在 JAX 中引入參考語義。以下是其主要功能
NNX 是 Python 式的:模組的正規 Python 語義,包括對可變性和共享參考的支援。
NNX 很簡單:Flax Linen 中許多複雜的 API 要么使用 Python 慣用語簡化,要么完全移除。
更好的 JAX 整合:自定義 NNX 轉換採用與 JAX 轉換相同的 API。使用 NNX,更容易直接使用 JAX 轉換 (高階函數)。
以下是一個簡單的 Flax NNX 程式範例,說明了上述許多要點
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # Eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # Reference sharing.
@nnx.jit # Automatic state management for JAX transforms.
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # in-place updates
return loss
Flax NNX 對 Linen 的改進#
本文檔的其餘部分使用各種範例來演示 Flax NNX 如何改進 Flax Linen。
檢視#
第一個改進是 Flax NNX 模組是正規的 Python 物件。這意味著您可以輕鬆地建構和檢視 Module
物件。
另一方面,Flax Linen 模組不容易檢視和偵錯,因為它們是延遲的,這意味著某些屬性在建構時不可用,只能在執行時存取。
class Block(nn.Module):
def setup(self):
self.linear = nn.Dense(10)
block = Block()
try:
block.linear # AttributeError: "Block" object has no attribute "linear".
except AttributeError as e:
pass
...
class Block(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(5, 10, rngs=rngs)
block = Block(nnx.Rngs(0))
block.linear
# Linear(
# kernel=Param(
# value=Array(shape=(5, 10), dtype=float32)
# ),
# bias=Param(
# value=Array(shape=(10,), dtype=float32)
# ),
# ...
請注意,在上面的 Flax NNX 範例中,沒有形狀推斷 - 輸入和輸出形狀都必須提供給 Linear
nnx.Module
。這是一個權衡,允許更明確和可預測的行為。
執行計算#
在 Flax Linen 中,所有頂級計算都必須透過 flax.linen.Module.init
或 flax.linen.Module.apply
方法完成,並且參數或任何其他類型的狀態都作為單獨的結構處理。這在以下兩者之間產生了不對稱:1) 可以在 apply
內部執行,可以直接執行方法和其他 Module
物件的程式碼;以及 2) 在 apply
外部執行的程式碼,必須使用 apply
方法。
在 Flax NNX 中,沒有特殊的上下文,因為參數作為屬性保留,並且可以直接呼叫方法。這意味著您的 NNX 模組的 __init__
和 __call__
方法與其他類別方法沒有區別,而 Flax Linen 模組的 setup()
和 __call__
方法是特殊的。
Encoder = lambda: nn.Dense(10)
Decoder = lambda: nn.Dense(2)
class AutoEncoder(nn.Module):
def setup(self):
self.encoder = Encoder()
self.decoder = Decoder()
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
x = jnp.ones((1, 2))
model = AutoEncoder()
params = model.init(random.key(0), x)['params']
y = model.apply({'params': params}, x)
z = model.apply({'params': params}, x, method='encode')
y = Decoder().apply({'params': params['decoder']}, z)
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)
class AutoEncoder(nnx.Module):
def __init__(self, rngs):
self.encoder = Encoder(rngs)
self.decoder = Decoder(rngs)
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
x = jnp.ones((1, 2))
model = AutoEncoder(nnx.Rngs(0))
y = model(x)
z = model.encode(x)
y = model.decoder(z)
在 Flax Linen 中,直接呼叫子模組是不可能的,因為它們沒有初始化。因此,您必須做的是建構一個新的實例,然後提供一個適當的參數結構。
但在 Flax NNX 中,您可以直接呼叫子模組而不會有任何問題。
狀態處理#
Flax Linen 眾所周知的複雜領域之一是狀態處理。當您使用 Dropout 層、BatchNorm 層或兩者時,您突然必須處理新的狀態,並使用它來配置 flax.linen.Module.apply
方法。
在 Flax NNX 中,狀態保存在 nnx.Module
內部,並且是可變的,這意味著可以直接呼叫它。
class Block(nn.Module):
train: bool
def setup(self):
self.linear = nn.Dense(10)
self.bn = nn.BatchNorm(use_running_average=not self.train)
self.dropout = nn.Dropout(0.1, deterministic=not self.train)
def __call__(self, x):
return nn.relu(self.dropout(self.bn(self.linear(x))))
x = jnp.ones((1, 5))
model = Block(train=True)
vs = model.init(random.key(0), x)
params, batch_stats = vs['params'], vs['batch_stats']
y, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
x,
rngs={'dropout': random.key(1)},
mutable=['batch_stats'],
)
batch_stats = updates['batch_stats']
class Block(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(5, 10, rngs=rngs)
self.bn = nnx.BatchNorm(10, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return nnx.relu(self.dropout(self.bn(self.linear(x))))
x = jnp.ones((1, 5))
model = Block(nnx.Rngs(0))
y = model(x)
...
Flax NNX 狀態處理的主要優點是,當您添加新的有狀態層時,不必更改訓練程式碼。
此外,在 Flax NNX 中,處理狀態的層也很容易實作。以下是一個簡化的 BatchNorm
層版本,每次呼叫時都會更新平均值和變異數。
class BatchNorm(nnx.Module):
def __init__(self, features: int, mu: float = 0.95):
# Variables
self.scale = nnx.Param(jax.numpy.ones((features,)))
self.bias = nnx.Param(jax.numpy.zeros((features,)))
self.mean = nnx.BatchStat(jax.numpy.zeros((features,)))
self.var = nnx.BatchStat(jax.numpy.ones((features,)))
self.mu = mu # Static
def __call__(self, x):
mean = jax.numpy.mean(x, axis=-1)
var = jax.numpy.var(x, axis=-1)
# ema updates
self.mean.value = self.mu * self.mean + (1 - self.mu) * mean
self.var.value = self.mu * self.var + (1 - self.mu) * var
# normalize and scale
x = (x - mean) / jax.numpy.sqrt(var + 1e-5)
return x * self.scale + self.bias
模型手術#
在 Flax Linen 中,由於以下兩個原因,模型手術在歷史上一直具有挑戰性
由於延遲初始化,不能保證您可以將子
Module
替換為新的子Module
。參數結構與
flax.linen.Module
結構分離,這意味著您必須手動保持它們同步。
在 Flax NNX 中,您可以根據 Python 語義直接替換子模組。由於參數是 nnx.Module
結構的一部分,它們永遠不會不同步。以下是如何實作 LoRA 層的範例,然後使用它來替換現有模型中的 Linear
層。
class LoraLinear(nn.Module):
linear: nn.Dense
rank: int
@nn.compact
def __call__(self, x: jax.Array):
A = self.param(random.normal, (x.shape[-1], self.rank))
B = self.param(random.normal, (self.rank, self.linear.features))
return self.linear(x) + x @ A @ B
try:
model = Block(train=True)
model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR
lora_params = model.linear.init(random.key(1), x)
lora_params['linear'] = params['linear']
params['linear'] = lora_params
except AttributeError as e:
pass
class LoraParam(nnx.Param): pass
class LoraLinear(nnx.Module):
def __init__(self, linear, rank, rngs):
self.linear = linear
self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank)))
self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features)))
def __call__(self, x: jax.Array):
return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = Block(rngs)
model.linear = LoraLinear(model.linear, rank=5, rngs=rngs)
...
如上所示,在 Flax Linen 中,這種情況實際上不起作用,因為 linear
子 Module
不可用。但是,其餘程式碼提供了如何手動更新 params
結構的想法。
在 Flax Linen 中執行任意模型手術並不容易,目前 intercept_methods API 是對方法進行通用修補的唯一方法。但是這個 API 並不是很符合人體工學。
在 Flax NNX 中,要進行通用模型手術,您可以直接使用 nnx.iter_graph
,這比 Linen 簡單且容易得多。以下是一個範例,說明如何將模型中的所有 nnx.Linear
層替換為自定義的 LoraLinear
NNX 層。
rngs = nnx.Rngs(0)
model = Block(rngs)
for path, module in nnx.iter_graph(model):
if isinstance(module, nnx.Module):
for name, value in vars(module).items():
if isinstance(value, nnx.Linear):
setattr(module, name, LoraLinear(value, rank=5, rngs=rngs))
轉換#
Flax Linen 轉換非常強大,因為它們可以對模型的狀態進行細粒度的控制。但是,Flax Linen 轉換也有缺點,例如
它們公開了不屬於 JAX 的其他 API,使得它們的行為令人困惑,有時與 JAX 的對應物不同。這也限制了您與 JAX 轉換互動並跟上 JAX API 變化的方式。
它們作用於具有非常特定簽名的函數,即
flax.linen.Module
必須是第一個參數。它們接受其他
Module
物件作為參數,但不作為回傳值。
它們只能在
flax.linen.Module.apply
內部使用。
另一方面,Flax NNX 轉換 旨在與它們對應的 JAX 轉換 等效,但有一個例外 - 它們可以用於 Flax NNX 模組。這表示 Flax 轉換
具有與 JAX 轉換相同的 API。
可以在任何參數上接受 Flax NNX 模組,並且
nnx.Module
物件可以從它/它們回傳。可以在任何地方使用,包括訓練迴圈。
以下是一個使用 vmap
和 Flax NNX 的範例,它透過轉換 create_weights
函數來創建權重堆疊,該函數回傳一些 Weights
,並透過轉換 vector_dot
函數將該權重堆疊個別應用於一批輸入,該函數將 Weights
作為第一個參數,並將一批輸入作為第二個參數。
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
def create_weights(seed: jax.Array):
return Weights(
kernel=random.uniform(random.key(seed), (2, 3)),
bias=jnp.zeros((3,)),
)
def vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ weights.kernel + weights.bias
seeds = jnp.arange(10)
weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds)
x = jax.random.normal(random.key(1), (10, 2))
y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)
與 Flax Linen 轉換相反,in_axes
參數和其他 API 會影響 nnx.Module
狀態的轉換方式。
此外,Flax NNX 轉換可以用作方法裝飾器,因為 nnx.Module
方法只是將 Module
作為第一個參數的函數。這表示先前的範例可以重寫如下
class WeightStack(nnx.Module):
@nnx.vmap(in_axes=(0, 0), out_axes=0)
def __init__(self, seed: jax.Array):
self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
self.bias = nnx.Param(jnp.zeros((3,)))
@nnx.vmap(in_axes=(0, 0), out_axes=1)
def __call__(self, x: jax.Array):
assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ self.kernel + self.bias
weights = WeightStack(jnp.arange(10))
x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)