從 Flax Linen 到 NNX 的演進#

本指南示範了 Flax Linen 和 Flax NNX 模型之間的差異,並提供並排的範例程式碼,以幫助您從 Flax Linen 遷移到 Flax NNX API。

本文檔主要教導如何將任意的 Flax Linen 程式碼轉換為 Flax NNX。如果您想要以「安全」的方式迭代轉換程式碼庫,請查看透過 nnx.bridge 一起使用 Flax NNX 和 Linen 指南。

為了充分利用本指南,強烈建議先閱讀Flax NNX 基礎知識文件,其中涵蓋了 nnx.Module 系統、Flax 轉換,以及帶有範例的 Functional API

基本的 Module 定義#

Flax Linen 和 Flax NNX 都使用 Module 類別作為表達神經網路庫層的預設單元。在下面的範例中,您首先建立一個 Block (透過繼承 Module),它由一個帶有 dropout 和 ReLU 激活函數的線性層組成;然後,當建立 Model (也是透過繼承 Module)時,您將其用作子 Module,它由 Block 和一個線性層組成。

Flax Linen 和 Flax NNX Module 物件之間有兩個根本差異

  • 無狀態 vs 有狀態flax.linen.Module (nn.Module) 實例是無狀態的 - 變數是從純函數式的 Module.init() 呼叫中返回,並單獨管理。flax.nnx.Module 則將其變數作為此 Python 物件的屬性擁有。

  • 惰性 vs 急切flax.linen.Module 僅在其看到輸入(惰性)時才分配空間來建立變數。 flax.nnx.Module 實例會在它們被實例化時,在看到範例輸入(急切)之前就建立變數。

  • Flax Linen 可以使用 @nn.compact 裝飾器在單個方法中定義模型,並使用來自輸入範例的形狀推斷。Flax NNX Module 通常要求額外的形狀資訊,以便在 __init__ 期間建立所有參數,並在 __call__ 方法中單獨定義計算。

import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

變數建立#

接下來,讓我們討論實例化模型和初始化其參數

  • 要為 Flax Linen 模型生成模型參數,請使用 jax.random.key (doc) 以及模型應接受的一些範例輸入來呼叫 flax.linen.Module.init (nn.Module.init) 方法。這會產生一個巢狀的 JAX 陣列jax.Array 資料類型)字典,這些陣列將被攜帶和單獨維護。

  • 在 Flax NNX 中,模型參數會在您實例化模型時自動初始化,並且變數(nnx.Variable 物件)會作為屬性儲存在 nnx.Module(或其子 Module)內部。您仍然需要為它提供一個 偽隨機數產生器 (PRNG) 金鑰,但該金鑰將被包裝在 nnx.Rngs 類別中並儲存在內部,以便在需要時生成更多 PRNG 金鑰。

如果您想以無狀態、類似字典的方式存取 Flax NNX 模型參數,以進行檢查點儲存或模型手術,請查看Flax NNX 分割/合併 APInnx.split / nnx.merge)。

model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]

assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))


# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

訓練步驟和編譯#

現在,讓我們繼續編寫訓練步驟並使用 JAX 即時編譯來編譯它。以下是 Flax Linen 和 Flax NNX 方法之間的一些差異。

編譯訓練步驟

  • Flax Linen 使用 @jax.jit - 一個 JAX 轉換 - 來編譯訓練步驟。

  • Flax NNX 使用 @nnx.jit - 一個 Flax NNX 轉換(幾個轉換 API 之一,其行為類似於 JAX 轉換,但也與 Flax NNX 物件良好協作)。因此,儘管 jax.jit 僅接受函數的純無狀態引數,nnx.jit 允許引數為有狀態的 NNX 模組。這大大減少了訓練步驟所需的程式碼行數。

取得梯度

  • 同樣地,Flax Linen 使用 jax.grad(一個用於 自動微分 的 JAX 轉換)來返回原始的梯度字典。

  • Flax NNX 使用 nnx.grad(一個 Flax NNX 轉換)來將 NNX 模組的梯度作為 nnx.State 字典返回。如果您想將常規的 jax.grad 與 Flax NNX 一起使用,您需要使用Flax NNX 分割/合併 API

最佳化器

  • 如果您已經在使用 Optax 優化器,例如 optax.adamw (而不是這裡顯示的原始 jax.tree.map 計算),請參考 Flax NNX 基礎指南中的 nnx.Optimizer 範例,這是一種更簡潔的方式來訓練和更新您的模型。

每個訓練步驟中的模型更新

  • Flax Linen 訓練步驟需要返回一個參數 pytree,作為下一步的輸入。

  • Flax NNX 訓練步驟不需要返回任何內容,因為 model 已經在 nnx.jit 內就地更新。

  • 此外,nnx.Module 物件是有狀態的,並且 Module 會自動追蹤其中的幾個項目,例如 PRNG 金鑰和 BatchNorm 統計數據。這就是為什麼您不需要在每個步驟都明確傳入 PRNG 金鑰的原因。另請注意,您可以使用 nnx.reseed 來重設其底層的 PRNG 狀態。

Dropout 行為

  • 在 Flax Linen 中,您需要明確定義並傳入 training 參數來控制 flax.linen.Dropout (nn.Dropout) 的行為,也就是它的 deterministic 標誌,這表示只有在 training=True 時才會發生隨機 dropout。

  • 在 Flax NNX 中,您可以呼叫 model.train() (flax.nnx.Module.train()) 來自動將 nnx.Dropout 切換到訓練模式。相反地,您可以呼叫 model.eval() (flax.nnx.Module.eval()) 來關閉訓練模式。您可以在其 API 參考中了解更多關於 nnx.Module.train 的作用。

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)

  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  return params
model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(inputs)




    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.GraphState.merge(params, rest))

集合和變數類型#

Flax Linen 和 NNX API 之間的一個主要差異在於它們如何將變數分組到類別中。Flax Linen 使用不同的集合,而 Flax NNX 則因為所有變數都應為頂層的 Python 屬性,所以您會使用不同的變數類型。

在 Flax NNX 中,您可以自由建立自己的變數類型作為 nnx.Variable 的子類別。

對於所有內建的 Flax Linen 層和集合,Flax NNX 已經建立對應的層和變數類型。例如

  • flax.linen.Dense (nn.Dense) 建立 params -> nnx.Linear 建立 :class:`nnx.Param<flax.nnx.Param>`。

  • flax.linen.BatchNorm (nn.BatchNorm) 建立 batch_stats -> nnx.BatchNorm 建立 nnx.BatchStats

  • flax.linen.Module.sow() 建立 intermediates -> nnx.Module.sow() 建立 nnx.Intermediaries

  • 在 Flax NNX 中,您也可以簡單地將中間值指定給 nnx.Module 屬性來取得中間值 - 例如,self.sowed = nnx.Intermediates(x)。這會類似於 Flax Linen 的 self.variable('intermediates' 'sowed', lambda: x)

class Block(nn.Module):
  features: int
  def setup(self):
    self.dense = nn.Dense(self.features)
    self.batchnorm = nn.BatchNorm(momentum=0.99)
    self.count = self.variable('counter', 'count',
                                lambda: jnp.zeros((), jnp.int32))


  @nn.compact
  def __call__(self, x, training: bool):
    x = self.dense(x)
    x = self.batchnorm(x, use_running_average=not training)
    self.count.value += 1
    x = jax.nn.relu(x)
    return x

x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape         # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape  # (4, )
variables['counter']['count']                        # 1
class Counter(nnx.Variable): pass

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )
    self.count = Counter(jnp.array(0))

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)
    self.count += 1
    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)
model.count           # Counter(value=...)

如果您想要從變數的 pytree 中提取特定的陣列

  • 在 Flax Linen 中,您可以存取特定的字典路徑。

  • 在 Flax NNX 中,您可以使用 nnx.split 來區分 Flax NNX 中的類型。下面的程式碼是一個簡單的範例,將變數按其類型分開 - 請查看 Flax NNX 篩選器指南以獲取更複雜的篩選表達式。

params, batch_stats, counter = (
  variables['params'], variables['batch_stats'], variables['counter'])
params.keys()       # ['dense', 'batchnorm']
batch_stats.keys()  # ['batchnorm']
counter.keys()      # ['count']

# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
  model, nnx.Param, nnx.BatchStat, Counter)
params.keys()       # ['batchnorm', 'linear']
batch_stats.keys()  # ['batchnorm']
count.keys()        # ['count']

# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)

使用多種方法#

在本節中,您將學習如何在 Flax Linen 和 Flax NNX 中使用多種方法。作為範例,您將實作一個具有三種方法的自動編碼器模型:encodedecode__call__

定義編碼器和解碼器層

  • 在 Flax Linen 中,與之前一樣,定義層時無需傳入輸入形狀,因為 flax.linen.Module 參數將使用形狀推斷延遲初始化。

  • 在 Flax NNX 中,您必須傳入輸入形狀,因為 nnx.Module 參數將在沒有形狀推斷的情況下主動初始化。

class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):



  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))

變數結構如下

# variables['params']
{
  decoder: {
      bias: (784,),
      kernel: (256, 784),
  },
  encoder: {
      bias: (256,),
      kernel: (784, 256),
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
})

呼叫 __call__ 以外的方法

  • 在 Flax Linen 中,您仍然需要使用 apply API。

  • 在 Flax NNX 中,您可以直接呼叫方法。

z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))

轉換#

Flax Linen 和 Flax NNX 轉換都提供自己的轉換集,這些轉換以可以使用 Module 物件的方式封裝 JAX 轉換

Flax Linen 中的大多數轉換,例如 gradjit,在 Flax NNX 中沒有太多變化。但是,例如,如果您嘗試對層執行 scan,如下一節所述,程式碼會大不相同。

讓我們先從一個範例開始

  • 首先,定義一個 RNNCell Module,其中將包含 RNN 單一步驟的邏輯。

  • 定義一個 initial_state 方法,該方法將用於初始化 RNN 的狀態(又名 carry)。與 jax.lax.scan (API 文件) 一樣,RNNCell.__call__ 方法將是一個接受 carry 和輸入並返回新 carry 和輸出的函數。在這種情況下,carry 和輸出是相同的。

class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下來,定義一個 RNN Module,其中將包含整個 RNN 的邏輯。

在 Flax Linen 中

  • 您將使用 flax.linen.scan (nn.scan) 來定義一個新的臨時類型,該類型封裝 RNNCell。在此過程中,您還將:1) 指示 nn.scan 廣播 params 集合 (所有步驟共享相同的參數),並且不要分割 params PRNG 串流 (以便所有步驟都使用相同的參數初始化);以及,最後 2) 指定您希望 scan 在輸入的第二個軸上執行,並沿著第二個軸堆疊輸出。

  • 然後,您將立即使用此臨時類型來建立「提升」的 RNNCell 的實例,並使用它來建立 carry,並執行 __call__ 方法,該方法將在序列上執行 scan

在 Flax NNX 中

  • 您將建立一個 scan 函式 (scan_fn),它會使用在 __init__ 中定義的 RNNCell 來掃描序列,並明確設定 in_axes=(nnx.Carry, None, 1)nnx.Carry 表示 carry 參數將會是 carry,None 表示 cell 將會廣播至所有步驟,而 1 表示 x 將會沿著軸 1 掃描。

class RNN(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(
      RNNCell, variable_broadcast='params',
      split_rngs={'params': False}, in_axes=1, out_axes=1
    )(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))

y = model(x)

掃描層#

一般來說,Flax Linen 和 Flax NNX 的轉換應該看起來相同。然而,Flax NNX 轉換 的設計更接近其底層的 JAX 對應物,因此我們在某些 Linen 提升的轉換中捨棄了一些假設。這個掃描層的使用案例將會是展示它的好例子。

掃描層是一種技術,您將輸入透過一連串 N 個重複的層來執行,並將每一層的輸出作為下一層的輸入。這種模式可以顯著減少大型模型的編譯時間。在下面的範例中,您將在頂層的 MLP Module 中重複 Block Module 5 次。

  • 在 Flax Linen 中,您將 flax.linen.scan (nn.scan) 轉換應用於 Block nn.Module 上,以建立一個更大的 ScanBlock nn.Module,其中包含 5 個 Block nn.Module 物件。它會在初始化時自動建立一個形狀為 (5, 64, 64) 的大型參數,並在呼叫時迭代每個 (64, 64) 切片共 5 次,就像 jax.lax.scan ( API 文件 ) 一樣。

  • 近距離來看,在這個模型的邏輯中,實際上並不需要在初始化時進行 jax.lax.scan 操作。那裡發生的事情更像是 jax.vmap 操作 - 您會得到一個接受 (in_dim, out_dim)BlockModule,並且您將其「vmap」遍歷 num_layers 次,以建立一個較大的陣列。

  • 在 Flax NNX 中,您可以利用模型初始化和執行程式碼完全解耦的事實,而是使用 nnx.vmap 轉換來初始化底層的 Block 參數,以及 nnx.scan 轉換來透過它們執行模型輸入。

如需更多關於 Flax NNX 轉換的資訊,請參閱 轉換指南

class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int




  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

model = MLP(64, num_layers=5)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)

model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

在上面的 Flax NNX 範例中還有一些其他細節需要解釋

  • `@nnx.split_rngs` 修飾器: Flax NNX 轉換完全與 PRNG 狀態無關,這使得它們的行為更像 JAX 轉換,但與處理 PRNG 狀態的 Flax Linen 轉換不同。為了重新獲得此功能,nnx.split_rngs 修飾器允許您在將 nnx.Rngs 傳遞給修飾的函式之前將其分割,並在之後「降低」它們,以便它們可以在外部使用。

    • 在這裡,您分割了 PRNG 金鑰,因為如果每個內部操作都需要自己的金鑰,則 jax.vmapjax.lax.scan 需要 PRNG 金鑰的清單。因此,對於 MLP 內部的 5 個層,您會分割並從其參數中提供 5 個不同的 PRNG 金鑰,然後再向下傳遞到 JAX 轉換。

    • 請注意,實際上 create_block() 知道它需要建立 5 層,正是因為 它看到 5 個 PRNG 金鑰,因為 in_axes=(0,) 表示 vmap 將會查看第一個參數的第一個維度,以了解它將要對其進行映射的大小。

    • 對於 forward() 也是如此,它會查看第一個參數(又名 model)內的變數,以找出它需要掃描多少次。nnx.split_rngs 在這裡實際上會分割 model 內的 PRNG 狀態。(如果 Block Module 沒有 dropout,則您不需要 nnx.split_rngs 行,因為它無論如何都不會消耗任何 PRNG 金鑰。)

  • 為什麼 Flax NNX 中的 Block Module 不需要接受和傳回額外的虛擬值: 這是 jax.lax.scan 的要求 (API 文件。Flax NNX 簡化了這一點,因此如果您設定 out_axes=nnx.Carry 而不是預設的 (nnx.Carry, 0),您現在可以選擇忽略第二個輸出。

    • 這是 Flax NNX 轉換與 JAX 轉換 API 不同的少數情況之一。

上面的 Flax NNX 範例中有更多程式碼行,但它們更精確地表達了每次發生的情況。由於 Flax NNX 轉換變得更接近 JAX 轉換 API,因此建議您在使用它們的 Flax NNX 對應物 之前,先對底層的 JAX 轉換 有很好的了解。

現在檢查兩邊的變數 pytree

# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
  ScanBlock_0: {
    Dense_0: {
      bias: (5, 64),
      kernel: (5, 64, 64),
    },
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
})

在 Flax NNX 中使用 TrainState#

Flax Linen 有一個方便的 TrainState 資料類別,用於捆綁模型、參數和最佳化器。在 Flax NNX 中,這並不是真的有必要。在本節中,您將學習如何在任何向後相容性需求下,圍繞 TrainState 建構您的 Flax NNX 程式碼。

在 Flax NNX 中

  • 您必須先在模型上呼叫 nnx.split,以取得單獨的 nnx.GraphDefnnx.State 物件。

  • 您可以傳入 nnx.Param,以將所有可訓練的參數篩選到單一 nnx.State 中,並傳入 ... 作為剩餘的變數。

  • 您還需要子類化 TrainState,以為其他變數新增一個欄位。

  • 然後,您可以傳入 nnx.GraphDef.apply 作為 apply 函式、nnx.State 作為參數和其他變數,以及一個最佳化器作為 TrainState 建構子的參數。

請注意,nnx.GraphDef.apply 會將 nnx.State 物件作為參數傳入,並傳回一個可呼叫的函式。此函式可以在輸入上呼叫,以輸出模型的 logits,以及更新的 nnx.GraphDefnnx.State 物件。請注意,由於您沒有將 Flax NNX Modules 傳遞到 train_step 中,因此下方使用了 @jax.jit

from flax.training import train_state

sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']




state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,

  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      inputs, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)


  state = state.apply_gradients(grads=grads)

  return state
from flax.training import train_state

model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

class TrainState(train_state.TrainState):
  other_variables: nnx.State

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  other_variables=other_variables,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, other_variables):
    logits, (graphdef, new_state) = state.apply_fn(
      params,
      other_variables

    )(inputs) # <== inputs
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params, state.other_variables)


  state = state.apply_gradients(grads=grads)

  return state