為何選擇 Flax NNX?#

在 2020 年,Flax 團隊發布了 Flax Linen API,以支援 JAX 上的模型研究,重點在於擴展性和效能。從那時起,我們從使用者身上學到了很多。團隊引入了一些已被證明對使用者有益的概念,例如

Flax 團隊所做的選擇之一是使用函數式 (compact) 語義,透過參數的延遲初始化進行神經網路編程。這使得實作程式碼簡潔,並使 Flax Linen API 與 Haiku 對齊。

然而,這也意味著 Flax 中模組和變數的語義是非 Python 式的,而且常常令人驚訝。它還導致了實作複雜性,並模糊了對神經網路進行轉換 (transforms)的核心概念。

介紹 Flax NNX#

快轉到 2024 年,Flax 團隊開發了 Flax NNX - 試圖保留使 Flax Linen 對使用者有用的功能,同時引入一些新的原則。Flax NNX 背後的中心思想是在 JAX 中引入參考語義。以下是其主要功能

  • NNX 是 Python 式的:模組的正規 Python 語義,包括對可變性和共享參考的支援。

  • NNX 很簡單:Flax Linen 中許多複雜的 API 要么使用 Python 慣用語簡化,要么完全移除。

  • 更好的 JAX 整合:自定義 NNX 轉換採用與 JAX 轉換相同的 API。使用 NNX,更容易直接使用 JAX 轉換 (高階函數)

以下是一個簡單的 Flax NNX 程式範例,說明了上述許多要點

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # Eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # Reference sharing.

@nnx.jit  # Automatic state management for JAX transforms.
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss

Flax NNX 對 Linen 的改進#

本文檔的其餘部分使用各種範例來演示 Flax NNX 如何改進 Flax Linen。

檢視#

第一個改進是 Flax NNX 模組是正規的 Python 物件。這意味著您可以輕鬆地建構和檢視 Module 物件。

另一方面,Flax Linen 模組不容易檢視和偵錯,因為它們是延遲的,這意味著某些屬性在建構時不可用,只能在執行時存取。

class Block(nn.Module):
  def setup(self):
    self.linear = nn.Dense(10)

block = Block()

try:
  block.linear  # AttributeError: "Block" object has no attribute "linear".
except AttributeError as e:
  pass





...
class Block(nnx.Module):
  def __init__(self, rngs):
    self.linear = nnx.Linear(5, 10, rngs=rngs)

block = Block(nnx.Rngs(0))


block.linear
# Linear(
#   kernel=Param(
#     value=Array(shape=(5, 10), dtype=float32)
#   ),
#   bias=Param(
#     value=Array(shape=(10,), dtype=float32)
#   ),
#   ...

請注意,在上面的 Flax NNX 範例中,沒有形狀推斷 - 輸入和輸出形狀都必須提供給 Linear nnx.Module。這是一個權衡,允許更明確和可預測的行為。

執行計算#

在 Flax Linen 中,所有頂級計算都必須透過 flax.linen.Module.initflax.linen.Module.apply 方法完成,並且參數或任何其他類型的狀態都作為單獨的結構處理。這在以下兩者之間產生了不對稱:1) 可以在 apply 內部執行,可以直接執行方法和其他 Module 物件的程式碼;以及 2) 在 apply 外部執行的程式碼,必須使用 apply 方法。

在 Flax NNX 中,沒有特殊的上下文,因為參數作為屬性保留,並且可以直接呼叫方法。這意味著您的 NNX 模組的 __init____call__ 方法與其他類別方法沒有區別,而 Flax Linen 模組的 setup()__call__ 方法是特殊的。

Encoder = lambda: nn.Dense(10)
Decoder = lambda: nn.Dense(2)

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

x = jnp.ones((1, 2))
model = AutoEncoder()
params = model.init(random.key(0), x)['params']

y = model.apply({'params': params}, x)
z = model.apply({'params': params}, x, method='encode')
y = Decoder().apply({'params': params['decoder']}, z)
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

x = jnp.ones((1, 2))
model = AutoEncoder(nnx.Rngs(0))


y = model(x)
z = model.encode(x)
y = model.decoder(z)

在 Flax Linen 中,直接呼叫子模組是不可能的,因為它們沒有初始化。因此,您必須做的是建構一個新的實例,然後提供一個適當的參數結構。

但在 Flax NNX 中,您可以直接呼叫子模組而不會有任何問題。

狀態處理#

Flax Linen 眾所周知的複雜領域之一是狀態處理。當您使用 Dropout 層、BatchNorm 層或兩者時,您突然必須處理新的狀態,並使用它來配置 flax.linen.Module.apply 方法。

在 Flax NNX 中,狀態保存在 nnx.Module 內部,並且是可變的,這意味著可以直接呼叫它。

class Block(nn.Module):
  train: bool

  def setup(self):
    self.linear = nn.Dense(10)
    self.bn = nn.BatchNorm(use_running_average=not self.train)
    self.dropout = nn.Dropout(0.1, deterministic=not self.train)

  def __call__(self, x):
    return nn.relu(self.dropout(self.bn(self.linear(x))))

x = jnp.ones((1, 5))
model = Block(train=True)
vs = model.init(random.key(0), x)
params, batch_stats = vs['params'], vs['batch_stats']

y, updates = model.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  rngs={'dropout': random.key(1)},
  mutable=['batch_stats'],
)
batch_stats = updates['batch_stats']
class Block(nnx.Module):


  def __init__(self, rngs):
    self.linear = nnx.Linear(5, 10, rngs=rngs)
    self.bn = nnx.BatchNorm(10, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.dropout(self.bn(self.linear(x))))

x = jnp.ones((1, 5))
model = Block(nnx.Rngs(0))



y = model(x)





...

Flax NNX 狀態處理的主要優點是,當您添加新的有狀態層時,不必更改訓練程式碼。

此外,在 Flax NNX 中,處理狀態的層也很容易實作。以下是一個簡化的 BatchNorm 層版本,每次呼叫時都會更新平均值和變異數。

class BatchNorm(nnx.Module):
  def __init__(self, features: int, mu: float = 0.95):
    # Variables
    self.scale = nnx.Param(jax.numpy.ones((features,)))
    self.bias = nnx.Param(jax.numpy.zeros((features,)))
    self.mean = nnx.BatchStat(jax.numpy.zeros((features,)))
    self.var = nnx.BatchStat(jax.numpy.ones((features,)))
    self.mu = mu  # Static

def __call__(self, x):
  mean = jax.numpy.mean(x, axis=-1)
  var = jax.numpy.var(x, axis=-1)
  # ema updates
  self.mean.value = self.mu * self.mean + (1 - self.mu) * mean
  self.var.value = self.mu * self.var + (1 - self.mu) * var
  # normalize and scale
  x = (x - mean) / jax.numpy.sqrt(var + 1e-5)
  return x * self.scale + self.bias

模型手術#

在 Flax Linen 中,由於以下兩個原因,模型手術在歷史上一直具有挑戰性

  1. 由於延遲初始化,不能保證您可以將子 Module 替換為新的子 Module

  2. 參數結構與 flax.linen.Module 結構分離,這意味著您必須手動保持它們同步。

在 Flax NNX 中,您可以根據 Python 語義直接替換子模組。由於參數是 nnx.Module 結構的一部分,它們永遠不會不同步。以下是如何實作 LoRA 層的範例,然後使用它來替換現有模型中的 Linear 層。

class LoraLinear(nn.Module):
  linear: nn.Dense
  rank: int

  @nn.compact
  def __call__(self, x: jax.Array):
    A = self.param(random.normal, (x.shape[-1], self.rank))
    B = self.param(random.normal, (self.rank, self.linear.features))

    return self.linear(x) + x @ A @ B

try:
  model = Block(train=True)
  model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR

  lora_params = model.linear.init(random.key(1), x)
  lora_params['linear'] = params['linear']
  params['linear'] = lora_params

except AttributeError as e:
  pass
class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear, rank, rngs):
    self.linear = linear
    self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank)))
    self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = Block(rngs)
model.linear = LoraLinear(model.linear, rank=5, rngs=rngs)






...

如上所示,在 Flax Linen 中,這種情況實際上不起作用,因為 linearModule 不可用。但是,其餘程式碼提供了如何手動更新 params 結構的想法。

在 Flax Linen 中執行任意模型手術並不容易,目前 intercept_methods API 是對方法進行通用修補的唯一方法。但是這個 API 並不是很符合人體工學。

在 Flax NNX 中,要進行通用模型手術,您可以直接使用 nnx.iter_graph,這比 Linen 簡單且容易得多。以下是一個範例,說明如何將模型中的所有 nnx.Linear 層替換為自定義的 LoraLinear NNX 層。

rngs = nnx.Rngs(0)
model = Block(rngs)

for path, module in nnx.iter_graph(model):
  if isinstance(module, nnx.Module):
    for name, value in vars(module).items():
      if isinstance(value, nnx.Linear):
        setattr(module, name, LoraLinear(value, rank=5, rngs=rngs))

轉換#

Flax Linen 轉換非常強大,因為它們可以對模型的狀態進行細粒度的控制。但是,Flax Linen 轉換也有缺點,例如

  1. 它們公開了不屬於 JAX 的其他 API,使得它們的行為令人困惑,有時與 JAX 的對應物不同。這也限制了您與 JAX 轉換互動並跟上 JAX API 變化的方式。

  2. 它們作用於具有非常特定簽名的函數,即

  • flax.linen.Module 必須是第一個參數。

  • 它們接受其他 Module 物件作為參數,但不作為回傳值。

  1. 它們只能在 flax.linen.Module.apply 內部使用。

另一方面,Flax NNX 轉換 旨在與它們對應的 JAX 轉換 等效,但有一個例外 - 它們可以用於 Flax NNX 模組。這表示 Flax 轉換

  1. 具有與 JAX 轉換相同的 API。

  2. 可以在任何參數上接受 Flax NNX 模組,並且 nnx.Module 物件可以從它/它們回傳。

  3. 可以在任何地方使用,包括訓練迴圈。

以下是一個使用 vmap 和 Flax NNX 的範例,它透過轉換 create_weights 函數來創建權重堆疊,該函數回傳一些 Weights,並透過轉換 vector_dot 函數將該權重堆疊個別應用於一批輸入,該函數將 Weights 作為第一個參數,並將一批輸入作為第二個參數。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

def vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ weights.kernel + weights.bias

seeds = jnp.arange(10)
weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds)

x = jax.random.normal(random.key(1), (10, 2))
y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)

與 Flax Linen 轉換相反,in_axes 參數和其他 API 會影響 nnx.Module 狀態的轉換方式。

此外,Flax NNX 轉換可以用作方法裝飾器,因為 nnx.Module 方法只是將 Module 作為第一個參數的函數。這表示先前的範例可以重寫如下

class WeightStack(nnx.Module):
  @nnx.vmap(in_axes=(0, 0), out_axes=0)
  def __init__(self, seed: jax.Array):
    self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=(0, 0), out_axes=1)
  def __call__(self, x: jax.Array):
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    return x @ self.kernel + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)