從 Flax Linen 到 NNX 的演進#
本指南示範了 Flax Linen 和 Flax NNX 模型之間的差異,並提供並排的範例程式碼,以幫助您從 Flax Linen 遷移到 Flax NNX API。
本文檔主要教導如何將任意的 Flax Linen 程式碼轉換為 Flax NNX。如果您想要以「安全」的方式迭代轉換程式碼庫,請查看透過 nnx.bridge 一起使用 Flax NNX 和 Linen 指南。
為了充分利用本指南,強烈建議先閱讀Flax NNX 基礎知識文件,其中涵蓋了 nnx.Module
系統、Flax 轉換,以及帶有範例的 Functional API。
基本的 Module
定義#
Flax Linen 和 Flax NNX 都使用 Module
類別作為表達神經網路庫層的預設單元。在下面的範例中,您首先建立一個 Block
(透過繼承 Module
),它由一個帶有 dropout 和 ReLU 激活函數的線性層組成;然後,當建立 Model
(也是透過繼承 Module
)時,您將其用作子 Module
,它由 Block
和一個線性層組成。
Flax Linen 和 Flax NNX Module
物件之間有兩個根本差異
無狀態 vs 有狀態:
flax.linen.Module
(nn.Module
) 實例是無狀態的 - 變數是從純函數式的Module.init()
呼叫中返回,並單獨管理。flax.nnx.Module
則將其變數作為此 Python 物件的屬性擁有。惰性 vs 急切:
flax.linen.Module
僅在其看到輸入(惰性)時才分配空間來建立變數。flax.nnx.Module
實例會在它們被實例化時,在看到範例輸入(急切)之前就建立變數。Flax Linen 可以使用
@nn.compact
裝飾器在單個方法中定義模型,並使用來自輸入範例的形狀推斷。Flax NNXModule
通常要求額外的形狀資訊,以便在__init__
期間建立所有參數,並在__call__
方法中單獨定義計算。
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
from flax import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
變數建立#
接下來,讓我們討論實例化模型和初始化其參數
要為 Flax Linen 模型生成模型參數,請使用
jax.random.key
(doc) 以及模型應接受的一些範例輸入來呼叫flax.linen.Module.init
(nn.Module.init
) 方法。這會產生一個巢狀的 JAX 陣列(jax.Array
資料類型)字典,這些陣列將被攜帶和單獨維護。在 Flax NNX 中,模型參數會在您實例化模型時自動初始化,並且變數(
nnx.Variable
物件)會作為屬性儲存在nnx.Module
(或其子Module
)內部。您仍然需要為它提供一個 偽隨機數產生器 (PRNG) 金鑰,但該金鑰將被包裝在nnx.Rngs
類別中並儲存在內部,以便在需要時生成更多 PRNG 金鑰。
如果您想以無狀態、類似字典的方式存取 Flax NNX 模型參數,以進行檢查點儲存或模型手術,請查看Flax NNX 分割/合併 API(nnx.split
/ nnx.merge
)。
model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]
assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# Parameters were already initialized during model instantiation.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
訓練步驟和編譯#
現在,讓我們繼續編寫訓練步驟並使用 JAX 即時編譯來編譯它。以下是 Flax Linen 和 Flax NNX 方法之間的一些差異。
編譯訓練步驟
Flax Linen 使用
@jax.jit
- 一個 JAX 轉換 - 來編譯訓練步驟。Flax NNX 使用
@nnx.jit
- 一個 Flax NNX 轉換(幾個轉換 API 之一,其行為類似於 JAX 轉換,但也與 Flax NNX 物件良好協作)。因此,儘管jax.jit
僅接受函數的純無狀態引數,nnx.jit
允許引數為有狀態的 NNX 模組。這大大減少了訓練步驟所需的程式碼行數。
取得梯度
同樣地,Flax Linen 使用
jax.grad
(一個用於 自動微分 的 JAX 轉換)來返回原始的梯度字典。Flax NNX 使用
nnx.grad
(一個 Flax NNX 轉換)來將 NNX 模組的梯度作為nnx.State
字典返回。如果您想將常規的jax.grad
與 Flax NNX 一起使用,您需要使用Flax NNX 分割/合併 API。
最佳化器
如果您已經在使用 Optax 優化器,例如
optax.adamw
(而不是這裡顯示的原始jax.tree.map
計算),請參考 Flax NNX 基礎指南中的nnx.Optimizer
範例,這是一種更簡潔的方式來訓練和更新您的模型。
每個訓練步驟中的模型更新
Flax Linen 訓練步驟需要返回一個參數 pytree,作為下一步的輸入。
Flax NNX 訓練步驟不需要返回任何內容,因為
model
已經在nnx.jit
內就地更新。此外,
nnx.Module
物件是有狀態的,並且Module
會自動追蹤其中的幾個項目,例如 PRNG 金鑰和BatchNorm
統計數據。這就是為什麼您不需要在每個步驟都明確傳入 PRNG 金鑰的原因。另請注意,您可以使用nnx.reseed
來重設其底層的 PRNG 狀態。
Dropout 行為
在 Flax Linen 中,您需要明確定義並傳入
training
參數來控制flax.linen.Dropout
(nn.Dropout
) 的行為,也就是它的deterministic
標誌,這表示只有在training=True
時才會發生隨機 dropout。在 Flax NNX 中,您可以呼叫
model.train()
(flax.nnx.Module.train()
) 來自動將nnx.Dropout
切換到訓練模式。相反地,您可以呼叫model.eval()
(flax.nnx.Module.eval()
) 來關閉訓練模式。您可以在其 API 參考中了解更多關於nnx.Module.train
的作用。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(inputs)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
集合和變數類型#
Flax Linen 和 NNX API 之間的一個主要差異在於它們如何將變數分組到類別中。Flax Linen 使用不同的集合,而 Flax NNX 則因為所有變數都應為頂層的 Python 屬性,所以您會使用不同的變數類型。
在 Flax NNX 中,您可以自由建立自己的變數類型作為 nnx.Variable
的子類別。
對於所有內建的 Flax Linen 層和集合,Flax NNX 已經建立對應的層和變數類型。例如
flax.linen.Dense
(nn.Dense
) 建立params
->nnx.Linear
建立 :class:`nnx.Param<flax.nnx.Param>`。flax.linen.BatchNorm
(nn.BatchNorm
) 建立batch_stats
->nnx.BatchNorm
建立nnx.BatchStats
。flax.linen.Module.sow()
建立intermediates
->nnx.Module.sow()
建立nnx.Intermediaries
。在 Flax NNX 中,您也可以簡單地將中間值指定給
nnx.Module
屬性來取得中間值 - 例如,self.sowed = nnx.Intermediates(x)
。這會類似於 Flax Linen 的self.variable('intermediates' 'sowed', lambda: x)
。
class Block(nn.Module):
features: int
def setup(self):
self.dense = nn.Dense(self.features)
self.batchnorm = nn.BatchNorm(momentum=0.99)
self.count = self.variable('counter', 'count',
lambda: jnp.zeros((), jnp.int32))
@nn.compact
def __call__(self, x, training: bool):
x = self.dense(x)
x = self.batchnorm(x, use_running_average=not training)
self.count.value += 1
x = jax.nn.relu(x)
return x
x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape # (4, )
variables['counter']['count'] # 1
class Counter(nnx.Variable): pass
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
self.count = Counter(jnp.array(0))
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
self.count += 1
x = jax.nn.relu(x)
return x
model = Block(4, 4, rngs=nnx.Rngs(0))
model.linear.kernel # Param(value=...)
model.batchnorm.mean # BatchStat(value=...)
model.count # Counter(value=...)
如果您想要從變數的 pytree 中提取特定的陣列
在 Flax Linen 中,您可以存取特定的字典路徑。
在 Flax NNX 中,您可以使用
nnx.split
來區分 Flax NNX 中的類型。下面的程式碼是一個簡單的範例,將變數按其類型分開 - 請查看 Flax NNX 篩選器指南以獲取更複雜的篩選表達式。
params, batch_stats, counter = (
variables['params'], variables['batch_stats'], variables['counter'])
params.keys() # ['dense', 'batchnorm']
batch_stats.keys() # ['batchnorm']
counter.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
model, nnx.Param, nnx.BatchStat, Counter)
params.keys() # ['batchnorm', 'linear']
batch_stats.keys() # ['batchnorm']
count.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)
使用多種方法#
在本節中,您將學習如何在 Flax Linen 和 Flax NNX 中使用多種方法。作為範例,您將實作一個具有三種方法的自動編碼器模型:encode
、decode
和 __call__
。
定義編碼器和解碼器層
在 Flax Linen 中,與之前一樣,定義層時無需傳入輸入形狀,因為
flax.linen.Module
參數將使用形狀推斷延遲初始化。在 Flax NNX 中,您必須傳入輸入形狀,因為
nnx.Module
參數將在沒有形狀推斷的情況下主動初始化。
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
變數結構如下
# variables['params']
{
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
},
'encoder': {
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
呼叫 __call__
以外的方法
在 Flax Linen 中,您仍然需要使用
apply
API。在 Flax NNX 中,您可以直接呼叫方法。
z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))
轉換#
Flax Linen 和 Flax NNX 轉換都提供自己的轉換集,這些轉換以可以使用 Module
物件的方式封裝 JAX 轉換。
Flax Linen 中的大多數轉換,例如 grad
或 jit
,在 Flax NNX 中沒有太多變化。但是,例如,如果您嘗試對層執行 scan
,如下一節所述,程式碼會大不相同。
讓我們先從一個範例開始
首先,定義一個
RNNCell
Module
,其中將包含 RNN 單一步驟的邏輯。定義一個
initial_state
方法,該方法將用於初始化 RNN 的狀態(又名carry
)。與jax.lax.scan
(API 文件) 一樣,RNNCell.__call__
方法將是一個接受 carry 和輸入並返回新 carry 和輸出的函數。在這種情況下,carry 和輸出是相同的。
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下來,定義一個 RNN
Module
,其中將包含整個 RNN 的邏輯。
在 Flax Linen 中
您將使用
flax.linen.scan
(nn.scan
) 來定義一個新的臨時類型,該類型封裝RNNCell
。在此過程中,您還將:1) 指示nn.scan
廣播params
集合 (所有步驟共享相同的參數),並且不要分割params
PRNG 串流 (以便所有步驟都使用相同的參數初始化);以及,最後 2) 指定您希望 scan 在輸入的第二個軸上執行,並沿著第二個軸堆疊輸出。然後,您將立即使用此臨時類型來建立「提升」的
RNNCell
的實例,並使用它來建立carry
,並執行__call__
方法,該方法將在序列上執行scan
。
在 Flax NNX 中
您將建立一個
scan
函式 (scan_fn
),它會使用在__init__
中定義的RNNCell
來掃描序列,並明確設定in_axes=(nnx.Carry, None, 1)
。nnx.Carry
表示carry
參數將會是 carry,None
表示cell
將會廣播至所有步驟,而1
表示x
將會沿著軸 1 掃描。
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(
RNNCell, variable_broadcast='params',
split_rngs={'params': False}, in_axes=1, out_axes=1
)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))
y = model(x)
掃描層#
一般來說,Flax Linen 和 Flax NNX 的轉換應該看起來相同。然而,Flax NNX 轉換 的設計更接近其底層的 JAX 對應物,因此我們在某些 Linen 提升的轉換中捨棄了一些假設。這個掃描層的使用案例將會是展示它的好例子。
掃描層是一種技術,您將輸入透過一連串 N 個重複的層來執行,並將每一層的輸出作為下一層的輸入。這種模式可以顯著減少大型模型的編譯時間。在下面的範例中,您將在頂層的 MLP
Module
中重複 Block
Module
5 次。
在 Flax Linen 中,您將
flax.linen.scan
(nn.scan
) 轉換應用於Block
nn.Module
上,以建立一個更大的ScanBlock
nn.Module
,其中包含 5 個Block
nn.Module
物件。它會在初始化時自動建立一個形狀為(5, 64, 64)
的大型參數,並在呼叫時迭代每個(64, 64)
切片共 5 次,就像jax.lax.scan
( API 文件 ) 一樣。近距離來看,在這個模型的邏輯中,實際上並不需要在初始化時進行
jax.lax.scan
操作。那裡發生的事情更像是jax.vmap
操作 - 您會得到一個接受(in_dim, out_dim)
的Block
子Module
,並且您將其「vmap」遍歷num_layers
次,以建立一個較大的陣列。在 Flax NNX 中,您可以利用模型初始化和執行程式碼完全解耦的事實,而是使用
nnx.vmap
轉換來初始化底層的Block
參數,以及nnx.scan
轉換來透過它們執行模型輸入。
如需更多關於 Flax NNX 轉換的資訊,請參閱 轉換指南。
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
model = MLP(64, num_layers=5)
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array): # No need to require a second input!
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x # No need to return a second output!
class MLP(nnx.Module):
def __init__(self, features, num_layers, rngs):
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return Block(features, features, rngs=rngs)
self.blocks = create_block(rngs)
self.num_layers = num_layers
def __call__(self, x):
@nnx.split_rngs(splits=self.num_layers)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, model):
x = model(x)
return x
return forward(x, self.blocks)
model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
在上面的 Flax NNX 範例中還有一些其他細節需要解釋
`@nnx.split_rngs` 修飾器: Flax NNX 轉換完全與 PRNG 狀態無關,這使得它們的行為更像 JAX 轉換,但與處理 PRNG 狀態的 Flax Linen 轉換不同。為了重新獲得此功能,
nnx.split_rngs
修飾器允許您在將nnx.Rngs
傳遞給修飾的函式之前將其分割,並在之後「降低」它們,以便它們可以在外部使用。在這裡,您分割了 PRNG 金鑰,因為如果每個內部操作都需要自己的金鑰,則
jax.vmap
和jax.lax.scan
需要 PRNG 金鑰的清單。因此,對於MLP
內部的 5 個層,您會分割並從其參數中提供 5 個不同的 PRNG 金鑰,然後再向下傳遞到 JAX 轉換。請注意,實際上
create_block()
知道它需要建立 5 層,正是因為 它看到 5 個 PRNG 金鑰,因為in_axes=(0,)
表示vmap
將會查看第一個參數的第一個維度,以了解它將要對其進行映射的大小。對於
forward()
也是如此,它會查看第一個參數(又名model
)內的變數,以找出它需要掃描多少次。nnx.split_rngs
在這裡實際上會分割model
內的 PRNG 狀態。(如果Block
Module
沒有 dropout,則您不需要nnx.split_rngs
行,因為它無論如何都不會消耗任何 PRNG 金鑰。)
為什麼 Flax NNX 中的 Block Module 不需要接受和傳回額外的虛擬值: 這是
jax.lax.scan
的要求 (API 文件。Flax NNX 簡化了這一點,因此如果您設定out_axes=nnx.Carry
而不是預設的(nnx.Carry, 0)
,您現在可以選擇忽略第二個輸出。這是 Flax NNX 轉換與 JAX 轉換 API 不同的少數情況之一。
上面的 Flax NNX 範例中有更多程式碼行,但它們更精確地表達了每次發生的情況。由於 Flax NNX 轉換變得更接近 JAX 轉換 API,因此建議您在使用它們的 Flax NNX 對應物 之前,先對底層的 JAX 轉換 有很好的了解。
現在檢查兩邊的變數 pytree
# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
在 Flax NNX 中使用 TrainState
#
Flax Linen 有一個方便的 TrainState
資料類別,用於捆綁模型、參數和最佳化器。在 Flax NNX 中,這並不是真的有必要。在本節中,您將學習如何在任何向後相容性需求下,圍繞 TrainState
建構您的 Flax NNX 程式碼。
在 Flax NNX 中
您必須先在模型上呼叫
nnx.split
,以取得單獨的nnx.GraphDef
和nnx.State
物件。您可以傳入
nnx.Param
,以將所有可訓練的參數篩選到單一nnx.State
中,並傳入...
作為剩餘的變數。您還需要子類化
TrainState
,以為其他變數新增一個欄位。然後,您可以傳入
nnx.GraphDef.apply
作為apply
函式、nnx.State
作為參數和其他變數,以及一個最佳化器作為TrainState
建構子的參數。
請注意,nnx.GraphDef.apply
會將 nnx.State
物件作為參數傳入,並傳回一個可呼叫的函式。此函式可以在輸入上呼叫,以輸出模型的 logits,以及更新的 nnx.GraphDef
和 nnx.State
物件。請注意,由於您沒有將 Flax NNX Modules 傳遞到 train_step
中,因此下方使用了 @jax.jit
。
from flax.training import train_state
sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(key, state, inputs, labels):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
inputs, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state
from flax.training import train_state
model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
class TrainState(train_state.TrainState):
other_variables: nnx.State
state = TrainState.create(
apply_fn=graphdef.apply,
params=params,
other_variables=other_variables,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, other_variables):
logits, (graphdef, new_state) = state.apply_fn(
params,
other_variables
)(inputs) # <== inputs
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params, state.other_variables)
state = state.apply_gradients(grads=grads)
return state