轉換#
一般來說,JAX 轉換(transforms)作用於 pytrees 的 jax.Array
,並遵守數值語義。這對於 Flax NNX 來說是一個挑戰,它將 nnx.Module
表示為遵循參考語義的常規 Python 物件。為了應對這個問題,Flax NNX 引入了自己的一組轉換,擴展了 JAX 轉換,允許 nnx.Module
和其他 Flax NNX 物件在轉換中傳入和傳出,同時保持參考語義。
如果您之前使用過 JAX 轉換,那麼 Flax NNX 轉換應該會感到非常熟悉。它們使用相同的 API,並且在僅使用 jax.Array
的 pytrees 時,其行為與 JAX 轉換類似。但是,當使用 Flax NNX 物件時,它們允許為這些物件保留 Python 的參考語義,包括
保留轉換的輸入和輸出中多個物件之間的共享參考。
將轉換內對物件所做的任何狀態變更,傳播到轉換外的物件。
當多個輸入和輸出之間存在別名時,強制執行物件轉換方式的一致性。
import jax
from jax import numpy as jnp, random
from flax import nnx
在本指南中,nnx.vmap
作為一個案例研究,用來說明 Flax NNX 轉換的工作原理。但是,本文檔中概述的原則適用於所有轉換。
基本範例#
首先,讓我們看一個使用 nnx.vmap
來擴展逐元素 vector_dot
函數以處理批次輸入的簡單範例。我們將定義一個沒有方法的 Weights
模組來保存一些參數,這些權重將作為輸入傳遞給 vector_dot
函數以及一些資料。權重和資料都將在軸 0
上進行批次處理,我們將使用 nnx.vmap
將 vector_dot
應用於每個批次元素,結果將在軸 1
上進行批次處理。
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))
def vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ weights.kernel + weights.bias
y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)
print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)
請注意,in_axes
自然地與 Weights
模組互動,將其視為 jax.Array
的 pytree。也允許使用前綴模式,因此 in_axes=(0, 0)
在這種情況下也適用。
物件也允許作為 Flax NNX 轉換的輸出,這對於轉換初始化器很有用。例如,您可以定義一個 create_weights
函數來建立一個單一 Weights
nnx.Module
,並使用 nnx.vmap
來建立一堆形狀與之前相同的 Weights
。
def create_weights(seed: jax.Array):
return Weights(
kernel=random.uniform(random.key(seed), (2, 3)),
bias=jnp.zeros((3,)),
)
seeds = jnp.arange(10)
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)
轉換方法#
Python 中的方法只是將實例作為第一個參數的函數,這表示您可以裝飾來自 Module
和其他 Flax NNX 子類的方法。例如,我們可以重構前面範例中的 Weights
,並使用 vmap
裝飾 __init__
來執行 create_weights
的工作,並新增一個 __call__
方法,並使用 @nnx.vmap
裝飾它來執行 vector_dot
的工作。
class WeightStack(nnx.Module):
@nnx.vmap
def __init__(self, seed: jax.Array):
self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
self.bias = nnx.Param(jnp.zeros((3,)))
@nnx.vmap(in_axes=0, out_axes=1)
def __call__(self, x: jax.Array):
assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ self.kernel + self.bias
weights = WeightStack(jnp.arange(10))
x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)
print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)
本指南的其餘部分將重點介紹轉換個別函數。但請注意,所有範例都可以用這種方法風格來編寫。
狀態傳播#
到目前為止,我們的函數都是無狀態的。但是,Flax NNX 轉換的真正威力在於您擁有有狀態函數的時候,因為它們的主要功能之一是傳播狀態變更以保留參考語義。讓我們更新先前的範例,方法是將一個 count
屬性新增到 Weights
,並在新 stateful_vector_dot
函數中遞增它。
class Count(nnx.Variable): pass
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))
def stateful_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
return x @ weights.kernel + weights.bias
y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)
weights.count
Count(
value=Array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int32)
)
執行一次 stateful_vector_dot
後,您會驗證 count
屬性已正確更新。由於 Weights
已向量化,count
已初始化為 arange(10)
,並且它的所有元素在轉換中都遞增了 1
。最重要的是,更新已傳播到轉換外的原始 Weights
物件。太棒了!
圖形更新傳播#
JAX 轉換將輸入視為 jax.Array
的 pytree,而 Flax NNX 將輸入視為 jax.Array
的 pytree 和 Python 參考,其中參考形成圖形。Flax NNX 的狀態傳播機制可以追蹤對物件的任意更新,只要它們是輸入的局部(不支援轉換內對全域的更新)。
這表示您可以根據需要修改圖形結構,包括更新現有屬性、新增/刪除屬性、交換屬性、在物件之間共享(新的)參考、在物件之間共享 nnx.Variable
等。天空才是極限!
以下範例示範在 nnx.vmap
內對 Weights
物件執行一些任意更新,並驗證這些更新是否已正確傳播到轉換外的原始 Weights
物件。
class Count(nnx.Variable): pass
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))
def crazy_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
y = x @ weights.kernel + weights.bias
weights.some_property = ['a', 2, False] # add attribute
del weights.bias # delete attribute
weights.new_param = weights.kernel # share reference
return y
y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)
nnx.display(weights)
能力越大,責任越大。
-班叔
雖然這個功能非常強大,但必須謹慎使用,因為它可能會與 JAX 對某些轉換的基本假設發生衝突。例如,jit
希望輸入的結構是穩定的,以便快取編譯後的函數,因此在 nnx.jit
-ed 函數內變更圖形結構會導致持續重新編譯和效能降低。另一方面,scan
只允許固定 carry
結構,因此新增/刪除宣告為 carry 的子狀態會導致錯誤。
轉換子狀態(提升類型)#
某些 JAX 轉換允許使用 pytree 前綴來指定應該如何轉換輸入/輸出的不同部分。Flax NNX 支援 pytree 結構的 pytree 前綴,但目前它沒有圖形物件前綴的概念。相反地,Flax NNX 引入了「提升類型」的概念,該概念允許指定應該如何轉換物件的不同子狀態。不同的轉換支援不同的提升類型,以下是目前每個 JAX 轉換支援的 Flax NNX 提升類型列表
提升類型 |
JAX 轉換 |
---|---|
|
|
|
|
|
|
注意:* 在撰寫本文檔的此版本時,Flax NNX
shard_map
尚未實作。
若要在 nnx.vmap
中指定如何向量化物件的不同子狀態,Flax 團隊建立了一個 nnx.StateAxes
。StateAxes
透過 Flax NNX 篩選器將一組子狀態對應到其對應的軸,並且您可以將 nnx.StateAxes
傳遞到 in_axes
和 out_axes
,就像它/它們是 pytree 前綴一樣。
讓我們使用先前的 stateful_vector_dot
範例,並且僅向量化 nnx.Param
變數,並廣播 count
變數,如此一來,我們只會為所有批次元素保留單一計數。為了做到這一點,我們將定義一個 nnx.StateAxes
,其帶有一個篩選器,會匹配 nnx.Param
變數並將它們映射到軸 0
,並將所有 Count
變數映射到 None
,並將此 nnx.StateAxes
傳遞給 Weights
物件的 in_axes
。
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.array(0),
)
x = jax.random.normal(random.key(1), (10, 2))
def stateful_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
return x @ weights.kernel + weights.bias
state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)
weights.count
Count(
value=Array(1, dtype=int32, weak_type=True)
)
在這裡,count
現在是一個純量,因為它沒有被向量化。此外,請注意 nnx.StateAxes
只能直接用於 Flax NNX 物件,並且不能作為物件 pytree 的前綴。
隨機狀態#
在 Flax NNX 中,隨機狀態只是一個常規狀態。這表示它會儲存在需要它的 nnx.Module
中,並且會被視為任何其他類型的狀態。這簡化了 Flax Linen 的機制,在 Flax Linen 中,隨機狀態是由一個獨立的機制處理。實際上,nnx.Module
只需要保留一個對 Rngs
物件的參照,該物件會在初始化期間傳遞給它們,並使用它來為每個隨機操作產生唯一的金鑰。就本指南而言,這表示隨機狀態可以像任何其他類型的狀態一樣進行轉換,但我們也需要了解狀態的佈局方式,以便我們能夠正確地轉換它。
假設您想稍微改變一下,並將相同的權重套用至批次中的所有元素。但是您也想為每個元素新增不同的隨機雜訊。
為了做到這一點,您會將一個 Rngs
屬性新增至 Weights
,此屬性是從建構期間傳遞的 seed
金鑰引數建立。這個 seed 金鑰必須事先經過 split
,以便您可以成功地對其進行向量化。基於教學原因,您會將 seed 金鑰指定給一個 noise
「串流」並從中取樣。為了向量化 PRNG 狀態,您必須將 nnx.StateAxes
設定為將所有 RngState
(Rngs
中所有變數的基底類別)映射到軸 0
,並將 nnx.Param
和 Count
映射到 None
。
class Weights(nnx.Module):
def __init__(self, kernel, bias, count, seed):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
self.rngs = nnx.Rngs(noise=seed)
weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
seed=random.split(random.key(0), num=10),
)
x = random.normal(random.key(1), (10, 2))
def noisy_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
y = x @ weights.kernel + weights.bias
return y + random.normal(weights.rngs.noise(), y.shape)
state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
print(jnp.allclose(y1, y2))
nnx.display(weights)
False
由於 Rngs
的狀態會原地更新,並由 nnx.vmap
自動傳播,因此每次呼叫 noisy_vector_dot
時,我們都會得到不同的結果。
在上面的範例中,您會在建構期間手動分割隨機狀態。這沒有問題,因為它清楚地表達了意圖,但它也不讓您在 nnx.vmap
之外使用 Rngs
,因為它的狀態永遠是分割的。為了解決這個問題,您可以傳遞一個未分割的 seed,並在 nnx.vmap
之前使用 nnx.split_rngs
裝飾器,以便在每次呼叫函式之前分割 RngState
,然後「降低」它使其變得可用。
weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
seed=0,
)
x = random.normal(random.key(1), (10, 2))
state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
@nnx.split_rngs(splits=10)
@nnx.vmap(in_axes=(state_axes, 0))
def noisy_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
y = x @ weights.kernel + weights.bias
return y + random.normal(weights.rngs.noise(), y.shape)
y1 = noisy_vector_dot(weights, x)
y2 = noisy_vector_dot(weights, x)
print(jnp.allclose(y1, y2))
nnx.display(weights)
False
規則和限制#
在本節中,我們將涵蓋在轉換內部使用模組時所套用的一些規則和限制。
可變動模組不能透過閉包傳遞#
雖然 Python 允許將物件作為閉包傳遞給函式,但 Flax NNX 轉換通常不支援這種做法。原因是,由於模組是可變動的,因此很容易將追蹤器擷取到轉換之外建立的模組中,這是 JAX 中的無聲錯誤。為了避免這種情況,Flax NNX 會檢查正在變動的模組和變數是否作為引數傳遞給已轉換的函式。
例如,如果我們有一個有狀態的模組,例如 Counter
,它會在每次呼叫時遞增計數器,而我們嘗試將其作為閉包傳遞給以 nnx.jit
修飾的函式,我們就會洩漏追蹤器。但是,Flax NNX 會改為引發錯誤來防止這種情況
class Counter(nnx.Module):
def __init__(self):
self.count = nnx.Param(jnp.array(0))
def increment(self):
self.count += jnp.array(1)
counter = Counter()
@nnx.jit
def f(x):
counter.increment()
return 2 * x
try:
y = f(3)
except Exception as e:
print(e)
為了解決這個問題,請將所有模組作為引數傳遞給正在轉換的函式。在這種情況下,f
應該接受 counter
作為引數。
一致的別名#
在轉換中允許參考語義的主要問題在於,參考可以在輸入和輸出之間共用。如果沒有妥善處理,這可能會產生問題,因為它會導致不適當或不一致的行為。在下面的範例中,您有一個單一的 Weights
模組 m
,其參考會多次出現在 arg1
和 arg2
中。這裡的問題在於,您也指定想要在軸 0
中向量化 arg1
,並在軸 1
中向量化 arg2
。這在 JAX 中沒有問題,因為 pytree 具有參考透明度。但是這在 Flax NNX 中會產生問題,因為您嘗試以兩種不同的方式向量化 m
。Flax NNX 將透過引發錯誤來強制執行一致性。
class Weights(nnx.Module):
def __init__(self, array: jax.Array):
self.param = nnx.Param(array)
m = Weights(jnp.arange(10))
arg1 = {'a': {'b': m}, 'c': m}
arg2 = [(m, m), m]
@nnx.vmap(in_axes=(0, 1))
def f(arg1, arg2):
...
try:
f(arg1, arg2)
except ValueError as e:
print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
param: 0
param: 0
param: 1
輸入和輸出之間也可能發生不一致的別名。在下一個範例中,您有一個簡單的函式,會接受並立即傳回 arg1
。但是,arg1
在輸入的軸 0
上進行向量化,並在輸出的軸 1
上進行向量化。如預期,這會產生問題,而 Flax NNX 會引發錯誤。
@nnx.vmap(in_axes=0, out_axes=1)
def f(arg1):
return arg1
try:
f(arg1)
except ValueError as e:
print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
param: 0
param: 0
param: 1
軸中繼資料#
Flax NNX Variable
可以保留任意的中繼資料,這些中繼資料可以透過將其作為關鍵字引數傳遞給其建構函式來新增。這通常用於儲存 sharding
資訊,如 nnx.spmd
API(例如 nnx.get_partition_spec
和 nnx.get_named_sharding
)所使用。
但是,當涉及轉換時,通常很重要的一點是讓這些與軸相關的資訊與軸的實際狀態保持同步。例如,如果您在軸 1
上向量化變數,則在 vmap
或 scan
內部時,應該移除位置 1
上的 sharding
資訊,以反映軸暫時移除的事實。
為了實現這一點,Flax NNX 轉換提供了一個非標準的 transform_metadata
字典引數。當 nnx.PARTITION_NAME
金鑰存在時,sharding
中繼資料會依照 in_axes
和 out_axes
的指定來更新。
讓我們看看一個實際運作的範例
class Weights(nnx.Module):
def __init__(self, array: jax.Array, sharding: tuple[str | None, ...]):
self.param = nnx.Param(array, sharding=sharding)
m = Weights(jnp.ones((3, 4, 5)), sharding=('a', 'b', None))
@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
print(f'Inner {m.param.shape = }')
print(f'Inner {m.param.sharding = }')
f(m)
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Inner m.param.shape = (3, 5)
Inner m.param.sharding = ('a', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)
在這裡,您將 sharding
中繼資料新增至 nnx.Param
變數,並使用 transform_metadata
來更新 sharding
中繼資料,以反映軸的變更。具體而言,您可以看到,當位於 nnx.vmap
內部時,第一個軸 b
已從 sharding
中繼資料中移除,然後在 nnx.vmap
外部時再將其新增回來。
您可以驗證,當在轉換內部建立 nnx.Module
時,這也會正常運作 - 新的 sharding
軸將會新增至轉換外部的 nnx.Module
nnx.Variable
,這會符合已轉換 nnx.Variable
的軸。
@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
return Weights(jnp.ones((3, 5)), sharding=('a', None))
m = init_vmap()
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)