同時使用 Flax NNX 和 Linen#
本指南適用於想要將其程式碼庫混合使用 Flax Linen 和 Flax NNX Module
的現有 Flax 使用者,這得益於 flax.nnx.bridge
API。
如果您有以下情況,這將很有幫助
想要逐步地將程式碼庫遷移到 NNX,一次遷移一個模組;
有已經移至 NNX 但您尚未遷移的外部依賴,或是在您遷移到 NNX 時仍然在 Linen 中。
我們希望這能讓您以自己的步調遷移並嘗試 NNX,並充分利用兩者的優勢。我們還將討論如何解決互操作兩個 API 的注意事項,因為它們在某些方面存在根本的不同。
注意:
本指南是關於膠合 Linen 和 NNX 模組。要將現有的 Linen 模組遷移到 NNX,請查看從 Flax Linen 遷移到 Flax NNX指南。
而且所有內建的 Linen 層都應該有等效的 NNX 版本!請查看內建 NNX 層列表。
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from flax import nnx
from flax import linen as nn
from flax.nnx import bridge
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from typing import *
子模組就是您所需要的#
Flax 模型始終是模組樹 - 舊的 Linen 模組 (flax.linen.Module
,通常寫成 nn.Module
) 或 NNX 模組 (nnx.Module
)。
nnx.bridge
包裝器以兩種方式將這兩種型別膠合在一起
nnx.bridge.ToNNX
:將 Linen 模組轉換為 NNX,使其可以成為另一個 NNX 模組的子模組,或獨立存在於 NNX 樣式的訓練迴圈中進行訓練。nnx.bridge.ToLinen
:反之亦然,將 NNX 模組轉換為 Linen。
這表示您可以採用自上而下或自下而上的行為:將整個 Linen 模組轉換為 NNX,然後逐漸向下移動,或將所有較低層級的模組轉換為 NNX,然後向上移動。
基礎知識#
Linen 和 NNX 模組之間存在兩個基本差異
無狀態 vs. 有狀態:Linen 模組實例是無狀態的:變數從純函數式
.init()
呼叫返回並單獨管理。然而,NNX 模組將其變數作為實例屬性擁有。惰性 vs. 急切:Linen 模組僅在實際看到其輸入時才分配空間來建立變數。然而,NNX 模組實例在實例化時會立即建立其變數,而無需看到範例輸入。
記住這一點,讓我們看看 nnx.bridge
包裝器如何解決這些差異。
Linen -> NNX#
由於 Linen 模組可能需要輸入才能建立變數,因此我們在從 Linen 轉換而來的 NNX 模組中半正式地支援了惰性初始化。Linen 變數會在您提供範例輸入時建立。
對於您來說,這是呼叫 nnx.bridge.lazy_init()
,您在 Linen 程式碼中呼叫 module.init()
的地方。
(注意:您可以對任何 NNX 模組呼叫 nnx.display
以檢查其所有變數和狀態。)
class LinenDot(nn.Module):
out_dim: int
w_init: Callable[..., Any] = nn.initializers.lecun_normal()
@nn.compact
def __call__(self, x):
# Linen might need the input shape to create the weight!
w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))
return x @ w
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64),
rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen
bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen
y = model(x) # => `y = model.apply(var, x)` in Linen
nnx.display(model)
# In-place swap your weight array and the model still works!
model.w.value = jax.random.normal(jax.random.key(1), (32, 64))
assert not jnp.allclose(y, model(x))
即使最上層模組是純 NNX 模組,nnx.bridge.lazy_init
也有效,因此您可以隨意執行子模組化
class NNXOuter(nnx.Module):
def __init__(self, out_dim: int, rngs: nnx.Rngs):
self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)
self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))
def __call__(self, x):
return self.dot(x) + self.b
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line
nnx.display(model)
Linen 權重已轉換為典型的 NNX 變數,它是實際 JAX 陣列值中的一個薄包裝器。在此,w
是一個 nnx.Param
,因為它屬於 LinenDot
模組的 params
集合。
我們將在NNX 變數 <-> Linen 集合部分中詳細討論不同的集合和型別。現在,只需知道它們會像原生變數一樣轉換為 NNX 變數。
assert isinstance(model.dot.w, nnx.Param)
assert isinstance(model.dot.w.value, jax.Array)
如果您建立此模型而不使用 nnx.bridge.lazy_init
,則在外部定義的 NNX 變數會像往常一樣初始化,但 Linen 部分 (包裝在 ToNNX
中) 將不會初始化。
partial_model = NNXOuter(64, rngs=nnx.Rngs(0))
nnx.display(partial_model)
full_model = bridge.lazy_init(partial_model, x)
nnx.display(full_model)
NNX -> Linen#
要將 NNX 模組轉換為 Linen,您應該將建立引數轉發到 bridge.ToLinen
,並讓它處理實際的建立過程。
這是因為 NNX 模組實例會在建立時急切地初始化其所有變數,這會消耗記憶體和計算資源。另一方面,Linen 模組是無狀態的,典型的 init
和 apply
過程涉及多次建立它們。因此,bridge.to_linen
會處理實際的模組建立,並確保不會分配兩次記憶體。
class NNXDot(nnx.Module):
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(
rngs.params(), (in_dim, out_dim)))
def __call__(self, x: jax.Array):
return x @ self.w
x = jax.random.normal(jax.random.key(42), (4, 32))
# Pass in the arguments, not an actual module
model = bridge.to_linen(NNXDot, 32, out_dim=64)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)
print(list(variables.keys()))
print(variables['params']['w'].shape) # => (32, 64)
print(y.shape) # => (4, 64)
['nnx', 'params']
(32, 64)
(4, 64)
請注意,ToLinen
模組需要追蹤額外的變數集合 - nnx
- 以取得基礎 NNX 模組的靜態元數據。
# This new field stores the static data that defines the underlying `NNXDot`
print(type(variables['nnx']['graphdef'])) # => `nnx.graph.NodeDef`
<class 'flax.nnx.graph.NodeDef'>
bridge.to_linen
實際上是 Linen 模組 bridge.ToLinen
周圍的便利包裝器。大多數情況下,您根本不需要直接使用 ToLinen
,除非您正在使用 ToLinen
的內建引數之一。例如,如果您的 NNX 模組不想使用 RNG 處理進行初始化
class NNXAddConstant(nnx.Module):
def __init__(self):
self.constant = nnx.Variable(jnp.array(1))
def __call__(self, x):
return x + self.constant
# You have to use `skip_rng=True` because this module's `__init__` don't
# take `rng` as argument
model = bridge.ToLinen(NNXAddConstant, skip_rng=True)
y, var = model.init_with_output(jax.random.key(0), x)
與 ToNNX
類似,您可以使用 ToLinen
來建立另一個 Linen 模組的子模組。
class LinenOuter(nn.Module):
out_dim: int
@nn.compact
def __call__(self, x):
dot = bridge.to_linen(NNXDot, x.shape[-1], self.out_dim)
b = self.param('b', nn.initializers.lecun_normal(), (1, self.out_dim))
return dot(x) + b
x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenOuter(out_dim=64)
y, variables = model.init_with_output(jax.random.key(0), x)
w, b = variables['params']['ToLinen_0']['w'], variables['params']['b']
print(w.shape, b.shape, y.shape)
(32, 64) (1, 64) (4, 64)
處理 RNG 金鑰#
所有 Flax 模組 (Linen 或 NNX) 都會自動處理變數建立和隨機層(如 dropout)的 RNG 金鑰。但是,RNG 金鑰分割的特定邏輯不同,因此即使您傳入相同的金鑰,也不能在 Linen 和 NNX 模組之間產生相同的參數。
另一個不同之處在於 NNX 模組是有狀態的,因此它們可以在自身內部追蹤和更新 RNG 金鑰。
Linen 到 NNX#
如果將 Linen 模組轉換為 NNX,您將享有有狀態的好處,並且無需在每次模組呼叫時傳入額外的 RNG 金鑰。您可以始終使用 nnx.reseed
來重設其中的 RNG 狀態。
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))
# We don't really need to call lazy_init because no extra params were created here,
# but it's a good practice to always add this line.
bridge.lazy_init(model, x)
y1, y2 = model(x), model(x)
assert not jnp.allclose(y1, y2) # Two runs yield different outputs!
# Reset the dropout RNG seed, so that next model run will be the same as the first.
nnx.reseed(model, dropout=0)
assert jnp.allclose(y1, model(x))
NNX 到 Linen#
如果將 NNX 模組轉換為 Linen,則基礎 NNX 模組的 RNG 狀態仍然是最上層 variables
的一部分。另一方面,Linen apply()
呼叫在每次呼叫時接受不同的 RNG 金鑰,這會重設內部 Linen 環境並允許產生不同的隨機資料。
現在,這實際上取決於您的基礎 NNX 模組是從其 RNG 狀態還是從傳入的引數產生新的隨機資料。幸運的是,nnx.Dropout
同時支援這兩種方式 - 如果有傳入的金鑰,則使用傳入的金鑰,否則使用其自身的 RNG 狀態。
這為您提供了兩種處理 RNG 金鑰的樣式選項
NNX 樣式(建議):讓基礎 NNX 狀態管理 RNG 金鑰,無需在
apply()
中傳入額外的金鑰。這表示每次 apply 呼叫都需要更多的程式碼行來變異variables
,但是一旦您的整個模型不再需要ToLinen
,事情就會變得更容易。Linen 樣式:只需為每次
apply()
呼叫傳入不同的 RNG 金鑰。
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.to_linen(nnx.Dropout, rate=0.5)
variables = model.init({'dropout': jax.random.key(0)}, x)
# The NNX RNG state was stored inside `variables`
print('The RNG key in state:', variables['RngKey']['rngs']['dropout']['key'].value)
print('Number of key splits:', variables['RngCount']['rngs']['dropout']['count'].value)
# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
print('Number of key splits after y2:', variables['RngCount']['rngs']['dropout']['count'].value)
assert not jnp.allclose(y1, y2) # Every call yields different output!
# Linen style: Just pass different RNG keys for every `apply()` call.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
assert not jnp.allclose(y3, y4) # Every call yields different output!
y5 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
assert jnp.allclose(y3, y5) # When you use same top-level RNG, outputs are same
The RNG key in state: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Number of key splits: 0
Number of key splits after y2: 2
NNX 變數類型 vs. Linen 集合#
當您想要將某些變數分組為一個類別時,在 Linen 中您會使用不同的集合;在 NNX 中,由於所有變數都應該是最上層的 Python 屬性,因此您會使用不同的變數類型。
因此,當混合使用 Linen 和 NNX 模組時,Flax 必須知道 Linen 集合和 NNX 變數類型之間的 1 對 1 對應關係,以便 ToNNX
和 ToLinen
可以自動執行轉換。
Flax 會保留此註冊表,並且它已經涵蓋了所有 Flax 的內建 Linen 集合。您可以使用 nnx.register_variable_name_type_pair
註冊 NNX 變數類型和 Linen 集合名稱的額外對應關係。
Linen 到 NNX#
對於 Linen 模組的任何集合,ToNNX
都會將其所有端點陣列(也稱為葉)轉換為 nnx.Variable
的子型別,可以是來自註冊表或動態自動建立。
(但是,我們仍然將整個集合保留為一個類別屬性,因為 Linen 模組在不同的集合中可能會有重複的名稱。)
class LinenMultiCollections(nn.Module):
out_dim: int
def setup(self):
self.w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim))
self.b = self.param('b', nn.zeros_init(), (self.out_dim,))
self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
def __call__(self, x):
if not self.is_initializing():
self.count.value += 1
y = x @ self.w + self.b
self.sow('intermediates', 'dot_sum', jnp.sum(y))
return y
x = jax.random.normal(jax.random.key(42), (2, 4))
model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)
print(model.w) # Of type `nnx.Param` - note this is still under attribute `params`
print(model.b) # Of type `nnx.Param`
print(model.count) # Of type `counter` - auto-created type from the collection name
print(type(model.count))
y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger
print(model.dot_sum) # Of type `nnx.Intermediates`
Param(
value=Array([[ 0.35401407, 0.38010964, -0.20674096],
[-0.7356256 , 0.35613298, -0.5099556 ],
[-0.4783049 , 0.4310735 , 0.30137998],
[-0.6102254 , -0.2668519 , -1.053598 ]], dtype=float32)
)
Param(
value=Array([0., 0., 0.], dtype=float32)
)
counter(
value=Array(0, dtype=int32)
)
<class 'flax.nnx.bridge.variables.counter'>
(Intermediate(
value=Array(6.9329877, dtype=float32)
),)
您可以使用 nnx.split
快速區分不同類型的 NNX 變數。
當您只想將某些變數設定為可訓練時,這會很方便。
# Separate variables of different types with nnx.split
CountType = type(model.count)
static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)
print('All Params:', list(params.keys()))
print('All Counters:', list(counter.keys()))
print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))
model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time
y = model(x, mutable=True) # still works!
All Params: ['b', 'w']
All Counters: ['count']
All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']
NNX 到 Linen#
如果您定義了自訂的 NNX 變數類型,您應該使用 nnx.register_variable_name_type_pair
註冊它們的名稱,以便它們能被歸類到所需的集合中。
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('counts', Count, overwrite=True)
class NNXMultiCollections(nnx.Module):
def __init__(self, din, dout, rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
self.lora = nnx.LoRA(din, 3, dout, rngs=rngs)
self.count = Count(jnp.array(0))
def __call__(self, x):
self.count += 1
return (x @ self.w.value) + self.lora(x)
xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(xkey, (2, 4))
model = bridge.to_linen(NNXMultiCollections, 4, 3)
var = model.init({'params': pkey, 'dropout': dkey}, x)
print('All Linen collections:', list(var.keys()))
print(var['params'])
All Linen collections: ['nnx', 'LoRAParam', 'counts', 'params']
{'w': Array([[ 0.2916921 , 0.22780475, 0.06553137],
[ 0.17487915, -0.34043145, 0.24764155],
[ 0.6420431 , 0.6220095 , -0.44769976],
[ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)}
分割元數據#
Flax 使用一個元數據包裝盒來包裝原始的 JAX 陣列,以註解變數應該如何分片。
在 Linen 中,這是一個可選的功能,透過在初始化器上使用 nn.with_partitioning
來觸發(詳情請參閱 Linen 分割元數據指南)。在 NNX 中,由於所有 NNX 變數都已經被 nnx.Variable
類別包裝,該類別也會持有分片註解。
如果您使用內建的註解方法(即 Linen 的 nn.with_partitioning
和 NNX 的 nnx.with_partitioning
),bridge.ToNNX
和 bridge.ToLinen
API 將會自動轉換分片註解。
Linen 到 NNX#
即使您沒有在 Linen 模組中使用任何分割元數據,變數的 JAX 陣列也會被轉換為包裝著真實 JAX 陣列的 nnx.Variable
物件。
如果您使用 nn.with_partitioning
來註解您的 Linen 模組的變數,該註解將會被轉換為對應的 nnx.Variable
中的 .sharding
欄位。
然後,您可以使用 nnx.with_sharding_constraint
在 jax.jit
編譯的函式中,將陣列明確地放入已註解的分割中,以正確的分片初始化整個模型。
class LinenDotWithPartitioning(nn.Module):
out_dim: int
@nn.compact
def __call__(self, x):
w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(),
('in', 'out')),
(x.shape[-1], self.out_dim))
return x @ w
@nnx.jit
def create_sharded_nnx_module(x):
model = bridge.lazy_init(
bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
state = nnx.state(model)
sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
nnx.update(model, sharded_state)
return model
print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')
mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),
axis_names=('in', 'out'))
x = jax.random.normal(jax.random.key(42), (4, 32))
with mesh:
model = create_sharded_nnx_module(x)
print(type(model.w)) # `nnx.Param`
print(model.w.sharding) # The partition annotation attached with `w`
print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh
We have 8 fake JAX devices now to partition this model...
<class 'flax.nnx.variablelib.Param'>
('in', 'out')
NamedSharding(mesh=Mesh('in': 2, 'out': 4), spec=PartitionSpec('in', 'out'), memory_kind=unpinned_host)
NNX 到 Linen#
如果您沒有使用 nnx.Variable
的任何元數據功能(即沒有分片註解,沒有註冊的鉤子),轉換後的 Linen 模組不會在您的 NNX 變數中新增元數據包裝,您無需擔心這一點。
但是,如果您確實將分片註解添加到您的 NNX 變數中,ToLinen
會將它們轉換為名為 bridge.NNXMeta
的預設 Linen 分割元數據類別,並保留您放入 NNX 變數中的所有元數據。
與任何 Linen 元數據包裝一樣,您可以使用 linen.unbox()
來取得原始的 JAX 陣列樹。
class NNXDotWithParititioning(nnx.Module):
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))
def __call__(self, x: jax.Array):
return x @ self.w
x = jax.random.normal(jax.random.key(42), (4, 32))
@jax.jit
def create_sharded_variables(key, x):
model = bridge.to_linen(NNXDotWithParititioning, 32, 64)
variables = model.init(key, x)
# A `NNXMeta` wrapper of the underlying `nnx.Param`
assert type(variables['params']['w']) == bridge.NNXMeta
# The annotation coming from the `nnx.Param` => (in, out)
assert variables['params']['w'].metadata['sharding'] == ('in', 'out')
unboxed_variables = nn.unbox(variables)
variable_pspecs = nn.get_partition_spec(variables)
assert isinstance(unboxed_variables['params']['w'], jax.Array)
assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')
sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint,
nn.unbox(variables),
nn.get_partition_spec(variables))
return sharded_vars
with mesh:
variables = create_sharded_variables(jax.random.key(0), x)
# The underlying JAX array is sharded across the 2x4 mesh
print(variables['params']['w'].sharding)
NamedSharding(mesh=Mesh('in': 2, 'out': 4), spec=PartitionSpec('in', 'out'), memory_kind=unpinned_host)
提升的轉換#
一般而言,如果您想在 nnx.bridge
轉換後的模組上應用 Linen/NNX 風格的提升轉換,只需按照常用的 Linen/NNX 語法進行即可。
對於 Linen 風格的轉換,請注意 bridge.ToLinen
是頂層的模組類別,因此您可能只想將其用作轉換的第一個參數(在大多數情況下,該參數需要是一個 linen.Module
類別)。
Linen 到 NNX#
NNX 風格的提升轉換與 JAX 轉換相似,它們作用於函式。
class NNXVmapped(nnx.Module):
def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs):
self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs)
self.vmap_axis_size = vmap_axis_size
def __call__(self, x):
@nnx.split_rngs(splits=self.vmap_axis_size)
@nnx.vmap(in_axes=(0, 0), axis_size=self.vmap_axis_size)
def vmap_fn(submodule, x):
return submodule(x)
return vmap_fn(self.linen_dot, x)
x = jax.random.normal(jax.random.key(0), (4, 32))
model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)
print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped
y = model(x)
print(y.shape)
(4, 32, 64)
(4, 64)
NNX 到 Linen#
請注意,bridge.ToLinen
是頂層的模組類別,因此您可能只想將其用作轉換的第一個參數(在大多數情況下,該參數需要是一個 linen.Module
類別)。
此外,由於 bridge.ToLinen
引入了這個額外的 nnx
集合,當使用軸變換轉換(linen.vmap
、linen.scan
等)時,您需要標記它,以確保它們被傳遞到內部。
class LinenVmapped(nn.Module):
dout: int
@nn.compact
def __call__(self, x):
inner = nn.vmap(bridge.ToLinen, variable_axes={'params': 0, 'nnx': None}, split_rngs={'params': True}
)(nnx.Linear, args=(x.shape[-1], self.dout))
return inner(x)
x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenVmapped(64)
var = model.init(jax.random.key(0), x)
print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped
y = model.apply(var, x)
print(y.shape)
(4, 32, 64)
(4, 64)