提升的轉換#

⚠️ 進階主題 ⚠️

此設計備忘會說明 flax.linen.transform 的底層實作,這會啟用 flax.linen.Module 中的 JAX 轉換。

引言#

JAX 使用函數式 API,表示在沒有副作用的情況下使用函數才能保證正確的行為 (JAX 文件)。一般來說,這些副作用是由改變函數外物件導致的。

函數式範例有一些優點,例如明確推論狀態與隨機性的能力。函數輸出僅在輸入引數變更時才會變更。因此,函數保證會以確定性的方式運作。

不過,純粹函數提供另一個對 JAX 的極大優點:確切來說,它們啟用函數轉換。例如 jax.vmap(f) 將向量化函數 f。因為 f 無法產生副作用,因此定義完整的 f 的向量化/平行版本。要找出為何需要此限制,請考量如果 f 會增加計數器或繪製亂數時會發生什麼事。對於向量中的每一項目,f 會為每一個項目繪製相同或不同的亂數嗎?批次中的每一個項目會有自己的計數器,或是所有項目共用計數器?如果以平行方式計算 f 時,會以什麼順序增加計數器?所有這些問題的答案都是「視情況而定」。行為模稜兩可,而函數限制優雅地避免此問題。

Flax 提出一個安全的方式,在相容於 JAX 的格式中提供有限的隨機性與狀態變數。Flax 中的狀態不會造成問題的原因,是因為它是本地的:Flax Module 內部有變數和 PRNG 序列,但外部只有 JAX 陣列和 PRNG 金鑰。

在大部分的用例中,Flax 用於以有狀態的方式定義模型。因為 Module 在外部表現得像純粹函數,因此我們可以使用所有轉換來充分利用 JAX。不過,有時候我們會想要透過使用轉換和 Module 並重,同時享受兩個世界的優點。此設計說明解釋我們如何延伸 JAX 的函數轉換到在包含內部狀態與隨機性的 Module 上運作。

函數化#

在深入細節之前,讓我們考量一個簡單的範例,在範例中我們希望在 Module 內部使用 vmap

首先,我們定義一個沒有任何轉換的簡單 MLP

import jax
from jax import random, numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    h = nn.Dense(4, name='hidden')(xs)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

現在,如果我們想要為 xs 中的每一項有獨立的 MLP 參數呢?如果這是「香草 JAX」,我們可以想像寫一些像是 jax.vmap(apply_mlp)(mlp_params, xs) 的東西。但是,在 Linen 中做這樣的事情實際上會失敗

class NaiveVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP()
    return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs)  # fails

vmap 用於 mlp 時,JAX 會引發一個錯誤,因為它不是 JAX 陣列或陣列的簡單容器。我們不能真的責怪 JAX 拒絕執行這項未指定的工作。畢竟,甚至不清楚這裡會發生什麼。MLP 內部的參數甚至尚未初始化,而我們需要為每一組參數一個單獨的 PRNG 鍵。 jax.vmap 只能廣播或映射到一個軸,但它不能自動分割一個 PRNG 鍵。因此,我們必須手動呼叫 jax.random.split

我們可以通過先將 MLP 變成一個純初始化和應用函數來修復此問題。之後,我們使用 param 方法來儲存參數

class ManualVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP(parent=None)
    init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params']
    apply_fn = jax.vmap(mlp.apply, in_axes=0)
    mlp_params = self.param('mlp', init_fn, xs)
    return apply_fn({'params': mlp_params}, xs)

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        hidden: {
            bias: (3, 4),
            kernel: (3, 4, 4),
        },
        out: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

這裡,MLP(parent=None) 建立一個 MLP 的分離實例。這避免為目前模組內的子模組保留名稱。儘管不是絕對必要,這也確保我們無法意外地以有狀態的方式使用 MLP 實例,而我們被迫透過 .init.apply 來使用它。

這個範例仍然相對簡潔,但已經需要一些額外的「記錄」陳述才能使其運作。但是,此實作有許多限制

  1. 在初始化期間,我們透過 init_fnapply_fn 兩次呼叫子模組。如果子模組使用相同的技巧來執行函式轉換,我們將會執行大量的程式碼,因為模組呼叫次數隨著 2^d 的方式增加,其中 d 是嵌套函式轉換的數量。

  2. 此實作假設子模組只需要參數 RNG 順序。

  3. 此實作假設我們僅在 init 期間在「params」集合中建立變數。但是,它不支援其他變數集合,也不支援在 apply 中建立/更新變數。

特別是第 3 點讓手動功能化變得繁瑣。請隨時嘗試使用 nn.BatchNorm 層將上述範例延伸至 MLP 模組。這需要處理一些額外的複雜性,例如儲存已更新的批次統計資料,並確保批次統計資料在應為不可變(例如:評估模式)的 vmap 中不可變。

我們稱轉換有狀態模組為純函式的過程為「功能化」。透過暫時將有狀態 Module 轉變為函式,我們使它相容於 JAX 的函式轉換。

提升#

Flax 提供手動功能化的替代方案,我們稱其為提升轉換。提升轉換定義在 flax.core.lift。所有 JAX 提升轉換都使用稱為 pack 的單一通用提升 API 定義。

必須做出一些決定才能定義 packpack 的實作控制變數和 rng 如何提升,以及使用者控制有多精細。它還必須決定提升決策是在變數或轉換定義時做出。

提升細緻度#

使用 Linen API,使用者可以定義任意變數集合和 PRNG 序列。集合中的每個變數都會以相同方式提升。

集合通常會給予語義上有意義的名稱,例如「params」或「batch_stats」,而不是通用名稱,例如「state」。因為集合具有語義意義,我們可以在轉換層級決定如何提升每個集合。例如,當我們將批次維度新增至模型時,我們想要共用所有參數變數。

同時,我們可以撰寫通用程式碼,使用轉換時不會確切知道子模組會建立哪種變數。因此,集合在細緻控制與普遍性之間取得平衡。我們也避免會因不穩定的字串比對程式碼在所有變數中尋找並嘗試基於命名慣例(例如:針對以「kernel」為字首的名稱鎖定所有變數)以不當的方式分割集合。如果需要更精細的控制,使用者可以輕鬆地將一組變數分割至應以不同方式處理的多個集合中。

轉換與變數控制#

Lifting 動作可以在轉換層級或變數定義時定義。我們使用轉換層級定義的 lifting 動作。選擇此方式的原因是許多不同的轉換具有各種不同的動作。例如:vmap 有廣播和向量化引數,而 scan 有 scan、carry 和廣播引數。否則變數必須為所有這些轉換定義其動作,否則 Module 將不與這些轉換相容。或者,我們必須為如何處理轉換做出預設決策。但是,這可能會導致靜態錯誤,因為動作可能在使用者意圖下實際上無效。

lift 套件還提供一個通用 transform,允許任意函數轉換變數集合。例如,這可用於透過轉置權重繫結連結式自編碼器中的權重。如果在變數定義時做出 lifting 決策,則不清楚是否可以定義類似的通用轉換。

Linen#

lifting 模組不了解 Linen Module API。相反,它直接在 flax.core.Scope 實例上執行操作。Scope 實例包含 Module 變數和 PRNG 序列。每個 Module 實例在 .scope 欄位中有 Scope 實例,如果它具有父系或它是使用 initapply 建立的。一般而言,頂層 Module 實例 (您呼叫 initapply 的實例) 是唯一沒有 Scope 繫結到它的 Module 實例。

當一個 Module 被轉換時,我們使用 flax.core.lift API 來 lifting 範圍並使用 Module.clone() 建立新的 Module 實例,將 lifting 範圍繫結到它。

flax.linen.transforms 公開 flax.core.lift 中變換的包裝器。核心提升 API 在函式上運作,而 Linen 包裝器可以轉換 Module 類別或 Module 方法。

因此,提升會獨立於 Linen API 實作。這種分工簡化了實作,並可能允許其他 Module 抽象化建立在用於提升和狀態管理的共同核心。

實作#

pack(fn, in_vars, out_vars, rngs) API 經過以下階段

  1. 範圍去重複

    此階段僅在同時提升多個範圍時相關。在這種情況下,我們必須首先找到根範圍集合。如果範圍的祖先都不在需要提升的範圍集合中,則該範圍就是根範圍。

    只提升根範圍,就可以避免提升同一個變數兩次。

    對於非根範圍,我們會儲存到其祖先範圍的參考和一條路徑,以便我們稍後可以重建它(第 4 階段)。

  2. 過濾階段

    將變數和 PRNG 序列分成群組。這樣,fn 可以將每個群組分別提升為轉換。群組定義為如下指定的過濾器

    • 集合/prng 名稱清單

    • True(比對全部)

    • False(不比對任何內容)

    • DenyList(filter)(比對除指定集合外的全部(例如:DenyList(['params']) 比對除「params」集合以外的全部)。

    集合或 PRNG 序列只能放入一個群組。如果集合比對到多個過濾器,它將放入具有相符過濾器的第一個群組中。如果集合或 PRNG 序列未比對到任何過濾器,則不會提升它。這表示它無法在轉換中使用,如果嘗試這樣做,會產生錯誤。舉例來說,in_vars = (["params"], True) 會將「params」集合放入第一個群組,而將所有其他集合放入第二個群組。

    對於比對到每個 PRNG 序列,我們會呼叫 make_rng 來植入新的 PRNG 序列。這能避免在提升轉換完成後更新 PRNG 狀態。

  3. 特定轉換提升

    fn會呼叫變數和 PRNG 群組。JAX 轉換具有不同的簽章和提升選項。最簡潔的範例應該是 vmap。在 vmap 的情況下,函式參數、PRNG 和變數集合會傳入 jax.vmap 封裝函式。

  4. 範圍重建

    現在變數和 PRNG 已在轉換內提升,我們要重新建立提升範圍。Pack 會呼叫 fn,並使用 scope_fn,此函式會取得提升後的變數和 PRNG,並傳回經重建的範圍,其中包含提升後的變數和 rng 順序。

  5. 重新打包階段

    在我們使用提升範圍之後,必須擷取更新後的變數(PRNG 順序可以自行捨棄)。pack 會傳入 repack_fn 來提供支援。此階段類似於階段 2,不過我們只會提升變數,而忽略不可變變數。不可變變數無法更新。因此,它們不應該從已轉換的函式中傳回。

  6. 提交階段

    pack 預計 fn 會傳回成對資訊,其中第一個項目會只從 pack 傳回,而第二個項目應為重新打包的變數。更新後的變數會儲存在原始/未提升的範圍中,以讓轉換內產生的變更可以在轉換完成後持續保留。

pack 使用範例#

使用 pack 將變數集合中每個矩陣轉置的精簡範例

from flax.core import lift
from flax.core import Scope, init, apply, nn as core_nn

def lift_transpose(fn, target='params', variables=True, rngs=True):
  # by default we transpose 'params' and simply pass through all other variables.
  def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args):
    # normally we would first call into a JAX transformed function here...
    target, rest = variable_groups
    def trans(x):
      if x.ndim == 2:
        return x.T
      return x
    target = jax.tree_util.tree_map(trans, target)
    variable_groups = (target, rest)
    scope = scope_fn(variable_groups, rng_groups)
    y = fn(scope, *args)
    out_variables = repack_fn(scope)
    return y, out_variables
  return lift.pack(
      wrapper,
      in_variable_filters=(target, variables),
      out_variable_filters=(variables,),
      rng_filters=(rngs,))

x = jnp.ones((3, 2))
y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4)

請注意,大多數使用者不應與 pack 直接互動。如果您發現現有的提升轉換尚未支援的使用範例,請開啟 GitHub 問題回報。

支援的轉換#

Jax 轉換

亞麻布支援嗎?

留言

vmap

scan

Carry 變數無法在掃描主體內初始化。

remat

jit

目前的實作可能會導致不必要的重新編譯。

jvp

vjp

custom_vjp

custom_jvp

while_loop

Carry 變數無法在 while_loop 主體內初始化。

cond

變數初始/變更必須在各分支之間進行結構性比對。

switch

變數初始/變更必須在各分支之間進行結構性比對。

pmap

xmap

參考資料

亞麻布範例#

回到我們的原始範例,我們現在可以使用 nn.vmap 進一步簡化我們的應用程式實作

class LinenVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0)
    return VmapMLP(name='mlp')(xs)

variables = LinenVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        Dense_0: {
            bias: (3, 4),
            kernel: (3, 2, 4),
        },
        Dense_1: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

這裡我們使用 variable_axes={'params': 0} 來標示參數應該是向量化而不是共用,而 split_rngs={'params': True} 則表示每組參數都是獨立初始化的。

我們也可以透過新增一個 BatchNorm 層,使用一些內部狀態來擴充這個範例

class StatefulMLP(nn.Module):
  @nn.compact
  def __call__(self, x, *, train):
    h = nn.Dense(4, name='hidden')(x)
    h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

class LinenStatefulVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs, *, train):
    VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0)
    return VmapMLP(name='mlp')(xs, train=train)
variables = LinenStatefulVmapMLP().init(random.key(0), xs)

我們只需要在 nn.vmap 中新增 'batch_stats': 0,表示區塊統計資料應該是向量化而不是在第一個軸線上共用。

替代方案#

其他數值運算架構會將變數視為一級公民。取代函式化的方法是使用一個變數系統,這個系統可能是整合在 JAX 中或是建構在 JAX 的基礎上。這種方法的優點是讓每個變數的提升更容易。如果變數是 JAX IR (JAXPR) 中的一部分,我們可以檢查在某些運算中哪些變數必須提升。選擇性地,它們可以附上一個集合標籤,以設定各種提升選項。

採用這種方法的缺點是,變數系統較為複雜。變數是相關的參照,並且會破壞函式式程式設計的核心假設(請參閱 參照透明度),而目前具有函式介面的其他 API 可能也需要整合(例如:檢查點與最佳化 API)。