轉換#

一般來說,JAX 轉換(transforms)作用於 pytreesjax.Array,並遵守數值語義。這對於 Flax NNX 來說是一個挑戰,它將 nnx.Module 表示為遵循參考語義的常規 Python 物件。為了應對這個問題,Flax NNX 引入了自己的一組轉換,擴展了 JAX 轉換,允許 nnx.Module 和其他 Flax NNX 物件在轉換中傳入和傳出,同時保持參考語義。

如果您之前使用過 JAX 轉換,那麼 Flax NNX 轉換應該會感到非常熟悉。它們使用相同的 API,並且在僅使用 jax.Array 的 pytrees 時,其行為與 JAX 轉換類似。但是,當使用 Flax NNX 物件時,它們允許為這些物件保留 Python 的參考語義,包括

  • 保留轉換的輸入和輸出中多個物件之間的共享參考。

  • 將轉換內對物件所做的任何狀態變更,傳播到轉換外的物件。

  • 當多個輸入和輸出之間存在別名時,強制執行物件轉換方式的一致性。

import jax
from jax import numpy as jnp, random
from flax import nnx

在本指南中,nnx.vmap 作為一個案例研究,用來說明 Flax NNX 轉換的工作原理。但是,本文檔中概述的原則適用於所有轉換。

基本範例#

首先,讓我們看一個使用 nnx.vmap 來擴展逐元素 vector_dot 函數以處理批次輸入的簡單範例。我們將定義一個沒有方法的 Weights 模組來保存一些參數,這些權重將作為輸入傳遞給 vector_dot 函數以及一些資料。權重和資料都將在軸 0 上進行批次處理,我們將使用 nnx.vmapvector_dot 應用於每個批次元素,結果將在軸 1 上進行批次處理。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))

def vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ weights.kernel + weights.bias

y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)

請注意,in_axes 自然地與 Weights 模組互動,將其視為 jax.Array 的 pytree。也允許使用前綴模式,因此 in_axes=(0, 0) 在這種情況下也適用。

物件也允許作為 Flax NNX 轉換的輸出,這對於轉換初始化器很有用。例如,您可以定義一個 create_weights 函數來建立一個單一 Weights nnx.Module,並使用 nnx.vmap 來建立一堆形狀與之前相同的 Weights

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

seeds = jnp.arange(10)
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)

轉換方法#

Python 中的方法只是將實例作為第一個參數的函數,這表示您可以裝飾來自 Module 和其他 Flax NNX 子類的方法。例如,我們可以重構前面範例中的 Weights,並使用 vmap 裝飾 __init__ 來執行 create_weights 的工作,並新增一個 __call__ 方法,並使用 @nnx.vmap 裝飾它來執行 vector_dot 的工作。

class WeightStack(nnx.Module):
  @nnx.vmap
  def __init__(self, seed: jax.Array):
    self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=0, out_axes=1)
  def __call__(self, x: jax.Array):
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    return x @ self.kernel + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)

本指南的其餘部分將重點介紹轉換個別函數。但請注意,所有範例都可以用這種方法風格來編寫。

狀態傳播#

到目前為止,我們的函數都是無狀態的。但是,Flax NNX 轉換的真正威力在於您擁有有狀態函數的時候,因為它們的主要功能之一是傳播狀態變更以保留參考語義。讓我們更新先前的範例,方法是將一個 count 屬性新增到 Weights,並在新 stateful_vector_dot 函數中遞增它。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias


y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)

weights.count
Count(
  value=Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
)

執行一次 stateful_vector_dot 後,您會驗證 count 屬性已正確更新。由於 Weights 已向量化,count 已初始化為 arange(10),並且它的所有元素在轉換中都遞增了 1。最重要的是,更新已傳播到轉換外的原始 Weights 物件。太棒了!

圖形更新傳播#

JAX 轉換將輸入視為 jax.Array 的 pytree,而 Flax NNX 將輸入視為 jax.Array 的 pytree 和 Python 參考,其中參考形成圖形。Flax NNX 的狀態傳播機制可以追蹤對物件的任意更新,只要它們是輸入的局部(不支援轉換內對全域的更新)。

這表示您可以根據需要修改圖形結構,包括更新現有屬性、新增/刪除屬性、交換屬性、在物件之間共享(新的)參考、在物件之間共享 nnx.Variable 等。天空才是極限!

以下範例示範在 nnx.vmap 內對 Weights 物件執行一些任意更新,並驗證這些更新是否已正確傳播到轉換外的原始 Weights 物件。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def crazy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  weights.some_property = ['a', 2, False] # add attribute
  del weights.bias # delete attribute
  weights.new_param = weights.kernel # share reference
  return y

y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)

nnx.display(weights)

能力越大,責任越大。
-班叔

雖然這個功能非常強大,但必須謹慎使用,因為它可能會與 JAX 對某些轉換的基本假設發生衝突。例如,jit 希望輸入的結構是穩定的,以便快取編譯後的函數,因此在 nnx.jit-ed 函數內變更圖形結構會導致持續重新編譯和效能降低。另一方面,scan 只允許固定 carry 結構,因此新增/刪除宣告為 carry 的子狀態會導致錯誤。

轉換子狀態(提升類型)#

某些 JAX 轉換允許使用 pytree 前綴來指定應該如何轉換輸入/輸出的不同部分。Flax NNX 支援 pytree 結構的 pytree 前綴,但目前它沒有圖形物件前綴的概念。相反地,Flax NNX 引入了「提升類型」的概念,該概念允許指定應該如何轉換物件的不同子狀態。不同的轉換支援不同的提升類型,以下是目前每個 JAX 轉換支援的 Flax NNX 提升類型列表

提升類型

JAX 轉換

StateAxes

vmappmapscan

StateSharding

jitshard_map*

DiffState

gradvalue_and_gradcustom_vjp

注意:* 在撰寫本文檔的此版本時,Flax NNX shard_map 尚未實作。

若要在 nnx.vmap 中指定如何向量化物件的不同子狀態,Flax 團隊建立了一個 nnx.StateAxesStateAxes 透過 Flax NNX 篩選器將一組子狀態對應到其對應的軸,並且您可以將 nnx.StateAxes 傳遞到 in_axesout_axes,就像它/它們是 pytree 前綴一樣。

讓我們使用先前的 stateful_vector_dot 範例,並且僅向量化 nnx.Param 變數,並廣播 count 變數,如此一來,我們只會為所有批次元素保留單一計數。為了做到這一點,我們將定義一個 nnx.StateAxes,其帶有一個篩選器,會匹配 nnx.Param 變數並將它們映射到軸 0,並將所有 Count 變數映射到 None,並將此 nnx.StateAxes 傳遞給 Weights 物件的 in_axes

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.array(0),
)
x = jax.random.normal(random.key(1), (10, 2))


def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)

weights.count
Count(
  value=Array(1, dtype=int32, weak_type=True)
)

在這裡,count 現在是一個純量,因為它沒有被向量化。此外,請注意 nnx.StateAxes 只能直接用於 Flax NNX 物件,並且不能作為物件 pytree 的前綴。

隨機狀態#

在 Flax NNX 中,隨機狀態只是一個常規狀態。這表示它會儲存在需要它的 nnx.Module 中,並且會被視為任何其他類型的狀態。這簡化了 Flax Linen 的機制,在 Flax Linen 中,隨機狀態是由一個獨立的機制處理。實際上,nnx.Module 只需要保留一個對 Rngs 物件的參照,該物件會在初始化期間傳遞給它們,並使用它來為每個隨機操作產生唯一的金鑰。就本指南而言,這表示隨機狀態可以像任何其他類型的狀態一樣進行轉換,但我們也需要了解狀態的佈局方式,以便我們能夠正確地轉換它。

假設您想稍微改變一下,並將相同的權重套用至批次中的所有元素。但是您也想為每個元素新增不同的隨機雜訊。

為了做到這一點,您會將一個 Rngs 屬性新增至 Weights,此屬性是從建構期間傳遞的 seed 金鑰引數建立。這個 seed 金鑰必須事先經過 split,以便您可以成功地對其進行向量化。基於教學原因,您會將 seed 金鑰指定給一個 noise「串流」並從中取樣。為了向量化 PRNG 狀態,您必須將 nnx.StateAxes 設定為將所有 RngStateRngs 中所有變數的基底類別)映射到軸 0,並將 nnx.ParamCount 映射到 None

class Weights(nnx.Module):
  def __init__(self, kernel, bias, count, seed):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)
    self.rngs = nnx.Rngs(noise=seed)

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=random.split(random.key(0), num=10),
)
x = random.normal(random.key(1), (10, 2))

def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

由於 Rngs 的狀態會原地更新,並由 nnx.vmap 自動傳播,因此每次呼叫 noisy_vector_dot 時,我們都會得到不同的結果。

在上面的範例中,您會在建構期間手動分割隨機狀態。這沒有問題,因為它清楚地表達了意圖,但它也不讓您在 nnx.vmap 之外使用 Rngs,因為它的狀態永遠是分割的。為了解決這個問題,您可以傳遞一個未分割的 seed,並在 nnx.vmap 之前使用 nnx.split_rngs 裝飾器,以便在每次呼叫函式之前分割 RngState,然後「降低」它使其變得可用。

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=0,
)
x = random.normal(random.key(1), (10, 2))

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})

@nnx.split_rngs(splits=10)
@nnx.vmap(in_axes=(state_axes, 0))
def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

y1 = noisy_vector_dot(weights, x)
y2 = noisy_vector_dot(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

規則和限制#

在本節中,我們將涵蓋在轉換內部使用模組時所套用的一些規則和限制。

可變動模組不能透過閉包傳遞#

雖然 Python 允許將物件作為閉包傳遞給函式,但 Flax NNX 轉換通常不支援這種做法。原因是,由於模組是可變動的,因此很容易將追蹤器擷取到轉換之外建立的模組中,這是 JAX 中的無聲錯誤。為了避免這種情況,Flax NNX 會檢查正在變動的模組和變數是否作為引數傳遞給已轉換的函式。

例如,如果我們有一個有狀態的模組,例如 Counter,它會在每次呼叫時遞增計數器,而我們嘗試將其作為閉包傳遞給以 nnx.jit 修飾的函式,我們就會洩漏追蹤器。但是,Flax NNX 會改為引發錯誤來防止這種情況

class Counter(nnx.Module):
  def __init__(self):
    self.count = nnx.Param(jnp.array(0))

  def increment(self):
    self.count += jnp.array(1)

counter = Counter()

@nnx.jit
def f(x):
  counter.increment()
  return 2 * x

try:
  y = f(3)
except Exception as e:
  print(e)

為了解決這個問題,請將所有模組作為引數傳遞給正在轉換的函式。在這種情況下,f 應該接受 counter 作為引數。

一致的別名#

在轉換中允許參考語義的主要問題在於,參考可以在輸入和輸出之間共用。如果沒有妥善處理,這可能會產生問題,因為它會導致不適當或不一致的行為。在下面的範例中,您有一個單一的 Weights 模組 m,其參考會多次出現在 arg1arg2 中。這裡的問題在於,您也指定想要在軸 0 中向量化 arg1,並在軸 1 中向量化 arg2。這在 JAX 中沒有問題,因為 pytree 具有參考透明度。但是這在 Flax NNX 中會產生問題,因為您嘗試以兩種不同的方式向量化 m。Flax NNX 將透過引發錯誤來強制執行一致性。

class Weights(nnx.Module):
  def __init__(self, array: jax.Array):
    self.param = nnx.Param(array)

m = Weights(jnp.arange(10))
arg1 = {'a': {'b': m}, 'c': m}
arg2 = [(m, m), m]

@nnx.vmap(in_axes=(0, 1))
def f(arg1, arg2):
  ...

try:
  f(arg1, arg2)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
  param: 0
  param: 0
  param: 1

輸入和輸出之間也可能發生不一致的別名。在下一個範例中,您有一個簡單的函式,會接受並立即傳回 arg1。但是,arg1 在輸入的軸 0 上進行向量化,並在輸出的軸 1 上進行向量化。如預期,這會產生問題,而 Flax NNX 會引發錯誤。

@nnx.vmap(in_axes=0, out_axes=1)
def f(arg1):
  return arg1

try:
  f(arg1)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
  param: 0
  param: 0
  param: 1

軸中繼資料#

Flax NNX Variable 可以保留任意的中繼資料,這些中繼資料可以透過將其作為關鍵字引數傳遞給其建構函式來新增。這通常用於儲存 sharding 資訊,如 nnx.spmd API(例如 nnx.get_partition_specnnx.get_named_sharding)所使用。

但是,當涉及轉換時,通常很重要的一點是讓這些與軸相關的資訊與軸的實際狀態保持同步。例如,如果您在軸 1 上向量化變數,則在 vmapscan 內部時,應該移除位置 1 上的 sharding 資訊,以反映軸暫時移除的事實。

為了實現這一點,Flax NNX 轉換提供了一個非標準的 transform_metadata 字典引數。當 nnx.PARTITION_NAME 金鑰存在時,sharding 中繼資料會依照 in_axesout_axes 的指定來更新。

讓我們看看一個實際運作的範例

class Weights(nnx.Module):
  def __init__(self, array: jax.Array, sharding: tuple[str | None, ...]):
    self.param = nnx.Param(array, sharding=sharding)

m = Weights(jnp.ones((3, 4, 5)), sharding=('a', 'b', None))

@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
  print(f'Inner {m.param.shape = }')
  print(f'Inner {m.param.sharding = }')

f(m)
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Inner m.param.shape = (3, 5)
Inner m.param.sharding = ('a', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)

在這裡,您將 sharding 中繼資料新增至 nnx.Param 變數,並使用 transform_metadata 來更新 sharding 中繼資料,以反映軸的變更。具體而言,您可以看到,當位於 nnx.vmap 內部時,第一個軸 b 已從 sharding 中繼資料中移除,然後在 nnx.vmap 外部時再將其新增回來。

您可以驗證,當在轉換內部建立 nnx.Module 時,這也會正常運作 - 新的 sharding 軸將會新增至轉換外部的 nnx.Module nnx.Variable,這會符合已轉換 nnx.Variable 的軸。

@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
  return Weights(jnp.ones((3, 5)), sharding=('a', None))

m = init_vmap()
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)