Flax 模組生命週期#
這個設計註釋是提供給已經熟悉 Flax Linen 模組,但希望能進一步了解抽像概念背後設計原則的使用者。本註釋應可讓您充分了解模組 API 奠基於哪些假設和保證。如果您目前對於模組毫無實際經驗,請查看 快速入門指南。
Flax Linen 模組在 Flax 核心之上提供一個 Python 風格的抽象概念。模組 抽象概念允許您在 JAX 之上建立具有狀態、參數與隨機性的類別。這是 Module
類別設計與行為的實務指南。在看完本指南後,您應該可以輕鬆跳脫常規,使用新方式應用模組。
概述#
定義#
首先來了解一下模組生命週期的概況。第一步,定義一個簡單的模組:
class MLP(nn.Module):
# 1. Attribute annotations
hidden_size: int
out_size: int
# 2. The ``setup`` method
def setup(self):
self.hidden = nn.Dense(self.hidden_size)
self.out = nn.Dense(self.out_size)
# 3. User methods
def __call__(self, x):
a = self.hidden(x)
h = nn.relu(a)
return self.out(h)
這個模組包含:
屬性註解,定義為 dataclass 欄位。這些註解會自動定義一個建構式。
``setup`` 方法,會建立子模組並將其指定給屬性。
使用者方法。根據慣例,大多數模組只有一個
__call__
方法,但您可以定義多個方法或使用不同的方法名稱。
建構/初始化#
現在我們要建構並使用 MLP
模組:
mlp = MLP(hidden_size=5, out_size=3)
x = jax.numpy.ones((1, 2))
variables = mlp.init(random.key(0), x)
y = mlp.apply(variables, x)
首先,我們建構一個 MLP
的實例,並傳遞建構屬性。請注意,如果您不習慣函式程式設計模式,這裡的建構方式和您預期的會不一樣。MLP
建構式並不會實際建立變數或任何內部狀態。最好將它視為包含功能的模組規格或範本,而不包含任何資料。
我們來仔細看看初始化。令人驚訝的是,Flax 中並沒有一個獨立的初始化路徑。init
只是 apply
的一個特例,您也可以將其寫成:
# equivalent to: variables = mlp.init(random.key(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True)
因此,init
只是一個包裝在 apply
周圍的函式,其中:
我們呼叫一個沒有任何初始變數(一個空字典)的模組。
一個名為
"params"
的 PRNG 產生器會一律傳遞,用於參數隨機初始設定(使用參數初始設定函數)。所有變數集合都會設為可變(
mutable=True
)。當集合為可變時,現有變數可以更新,也可以建立新的變數。因此,在init
內部,變數可以在任何變數集合中初始化,且它們都會新增到傳回的變數字典。
生命週期#
既然您已了解到 init
是 apply
的特殊個案,讓我們來進一步了解 .apply(...)
。事實上,Modules 大部分的複雜性都存在於 apply
方法中。「Modules 生命週期」包括建立 Modules 和對其套用 apply
。我們可以將 Modules 生命週期總括如下:
我們建立
mlp = MLP(hidden_size=5, out_size=3)
,使得mlp.hidden_size=5
且mlp.out_size=3
。然後,呼叫
mlp.apply
,它會建立
mlp
的克隆,我們稱之為mlp_copy
。呼叫
mlp_copy.setup()
。傳回
mlp_copy.__call__()
的輸出結果,以及選擇性地傳回使用關鍵字引數mutable=
指定為可變的變數集合。
請注意,生命週期包括複製 Modules 執行個體。此舉旨在確保 apply
可被視為純函數(即如果您傳入相同的引數,它會傳回相同的輸出)。您稍後將在 頂層 Modules 段落中進一步了解這個部分。
變數#
「變數」這個詞在程式設計和數學中無所不在。不過,在 JAX 和 Flax 的背景下,了解變數是什麼是很重要的。在 Flax Modules 中,變數 的作用如您預期中的 Python。變數一經初始化就會讀取,甚至會不時更新。不過,JAX 沒有變數的概念。相反地,值會儲存在類似於 NumPy 陣列的陣列中,但有一個重要的差異,那就是它們是不可變的。
init
與apply
方法會回傳一個變數為巢狀字典,其字串為索引鍵,且 JAX 陣列在最底層。在最上層,每個索引鍵都對應到一個變數集合。在每個集合內,巢狀的字典結構與Module
層級架構相符。變數字典為不可變,因此它只是一個變數狀態的快照。當apply
再次被呼叫時,變數字典會作為一個參數傳遞。這時,變數處於與先前init
/apply
呼叫結束時相同的狀態。
請注意
模組欄位使用field_name: TypeHint語法宣告 (與 dataclasses 相同)。若沒有類型提示,某個屬性則被視為該類別的靜態屬性。如果您無法指定類型,您可使用typing.Any
作為萬用類型。
緊湊模組#
Linen 提供了另一種 API,可用於更緊湊地定義模組。在模組僅由使用參數和/或子模組的方法組成時,這特別有用。使用緊湊 API,MLP 可改寫成如下
class CompactMLP(nn.Module):
hidden_size: int
out_size: int
@nn.compact
def __call__(self, x):
a = nn.Dense(self.hidden_size)(x)
h = nn.relu(a)
return nn.Dense(self.out_size)(h)
緊湊Module
在精神上類似函數。它提供了簡潔的表示法,並將外部互動限制為函數的輸入和回傳值。在這裡,簡潔的表示法可能有助於他人理解模組的作用。您無需在setup
和__call__
方法之間來回跳動,即可理解子模組的作用。相反地,只需從上到下讀一次__call__
方法,即可提供簡潔的概觀。如果您實作具有許多超參數的複雜模組,這會產生顯著幫助。請參閱setup 或 compact,以取得關於如何決定 setup 或 compact 的實用指南。
在內聯定義子模組和/或變數的另一個好處是,您能在建構變數時,將參數新增到您的方法中。最常見的範例是使用形狀資訊來決定參數的形狀,例如這樣
class CompactScaledMLP(nn.Module):
hidden_size: int
out_size: int
@nn.compact
def __call__(self, x):
scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:])
x *= scale[None]
a = nn.Dense(self.hidden_size)(x)
h = nn.relu(a)
return nn.Dense(self.out_size)(h)
許多標準 Linen 模組,例如nn.Dense
,早已使用形狀推論,以避免需要指定輸入形狀 (例如 Dense 層的輸入特徵數)。
緊湊控制流程#
如果你未明確提供子模組的名稱(使用傳遞給模組建構函式的 name=
關鍵字參數),定義子模組的順序會決定子模組的名稱。由於 name
決定參數會如何對應到子模組,你必須小心混合控制流程和自動產生的名稱。使用控制流程可能會變更順序或完全移除某些子模組。如果某個子模組應該僅根據某些建構參數而存在,這很有用。然而,當控制流程依賴於模組的輸入參數時,你應該小心。例如,下面這個模組會中斷
class WrongModule(nn.Module):
@nn.compact
def __call__(self, x, mode):
if mode == "encode":
return nn.Dense(features=8)(x)
elif mode == "decode":
return nn.Dense(features=4)(x)
上面的模組會中斷,因為編碼器或解碼器路徑將會建立稱為「Dense_0」的模組。這表示兩個模組將會共用參數,這是這裡未預期的。實際上,這兩個模組無法共用參數,因為它們各自有不同的特色。
- 這個問題可以使用多種方法解決
提供明確名稱
在
setup
中建立模組或將建構函式移出控制流程外。後者如下所示
class CorrectModule(nn.Module):
@nn.compact
def __call__(self, x, mode):
encoder = nn.Dense(8)
decoder = nn.Dense(4)
if mode == "encode":
return encoder(x)
elif mode == "decode":
return decoder(x)
在上面的範例中,建構順序是固定的。建構完成後,子模組可以依任意順序使用。
請注意
精簡模組十分類似於 React 掛鉤。
頂層模組#
當一個模組實例是在「頂層」建立時,它會處於「未綁定」狀態,亦即,沒有任何變數附加在上面。「頂層」是指它並非用另一個模組類別內的子模組所建構。除了呼叫 init
和 apply
以外,你無法對一個未綁定的模組做太多事。也要注意, setup
未在未綁定的模組上呼叫,所以你只能存取建構參數。參閱 後續工作 區段,了解未來這個部分可能會如何變更。
為什麼頂層模組總是未綁定的?#
當我們呼叫 apply
時,頂層模組的一個複本會被建立,這個複本實際上會包含變數和 PRNG 序列。這個有狀態、「已綁定」的複製只存在於我們執行套用方法時。原因是如果你建立一個有狀態物件,並且在套用函數回傳前就將之銷毀,則 apply
函數本身就像一個純粹函數一樣。一個純粹函數有兩個約束
如果你放入相同的參數,它會回傳相同的輸出
它不會變更函式外的任何內容。這表示您無法處理可從純函式外部存取的狀態物件。
純函式有很多好處,但使用 JAX 時,它們通常是必不可少的。例如,大多數程式碼需要使用 jax.jit
編譯才能快速運行,一旦建立了模組,您可能會想使用 jax.grad
最佳化其參數。然而,這些 API 會預期純函式,且無法直接對有狀態繫結的 Module
執行個體。此外,純函式允許與其他函式庫靈活地相互操作。例如,我們建議 Optax 來最佳化參數。Optax 中的最佳化程式預期並傳回 JAX 陣列的 PyTree 以進行最佳化,就像 Linen 模組的 apply
函式一樣。
複製#
若要讓此方法穩定地運作,我們需要明確定義的複製行為。Flax 不是仰賴像 Python 的 deepcopy
那樣的複雜巢狀複製程序,而是強制 Module
完全由其建構引數定義。因此,複製模組會簡化為使用其原始建構引數來呼叫建構函式。由於 Module
會充當不可變資料類別,因此建構引數會直接對應到執行個體屬性。在 setup
或 __post_init__
計算出來的非建構屬性,也應該僅依賴建構引數,才能確保複製是明確定義的。
設定#
在一般 Python 類別中,setup
方法通常使用類似於建構函數掛鉤(__init__
)的方式。然而,對於較進階的使用案例來說,最好了解它與建構函數不同。
setup
僅在模組繫結後呼叫。一般來說,這並非問題,因為大部分模組幾乎會立即繫結(作為 init
和 apply
的一部分)。在 setup
裡面,次模組會在指派給屬性時繫結。在用 nn.compact
修飾的方法內部,次模組會在建構時立即繫結。如同在先前章節中所述,頂層模組永遠不會繫結,因此在建構時不會呼叫 setup。這表示您無法從未繫結的頂層模組存取在 setup 中指派好的屬性。
class TopLevelAccess(nn.Module):
def setup(self):
self.foo = nn.Dense(2)
mdl = TopLevelAccess()
assert not hasattr(mdl, "foo") # foo is not defined because setup is not called
setup
方法並非在 Module
繫結後立即呼叫,而是在您與 Module
執行個體進行互動(例如:呼叫方法、存取屬性)時才會呼叫。這不應該影響 Module
的行為,只是延遲執行有時會影響除錯期間的紀錄陳述和堆疊追蹤。有關 functionalization 的章節將說明為什麼我們需要 setup
延遲執行的主要原因。
函式化#
到目前為止,我們有純粹的 apply
函式,通常會以一些 JAX 轉換做變形,而在 apply
內部我們有一個有狀態模組執行個體可以搭配使用。換句話說,在模組外部我們處於函式世界,我們可以使用 JAX 的函式轉換,而在模組內部我們可以用 Flax 的有狀態變數和 PRNG 序列,而且 apply
方法是兩個世界之間的橋樑。
但是,如果我們想要在模組內部使用 JAX 轉換怎麼辦?函式化就是答案。
這個程序本身又繁瑣,而且容易出錯,但 Flax 在內部會處理。高層次來說,我們可以將它歸納如下。對於模組內部所定義的方法 fn
收集應該在 JAX 轉換內部可用的模組(變數和 PRNG 順序),並截取它的快照。
使用原始引數和收集的狀態呼叫 JAX 轉換。然後在轉換的內部
解開狀態並重建模組
呼叫使用者程式碼
fn
收集更新的變數和 rng 並將其與來自
fn
的原始回傳值一起回傳
使用來自轉換回傳的更新狀態更新原始狀態。
可以在 提昇轉換設計備註中找到有關函數化和提昇的更深入說明。
實用後果#
在大部分情況下,函數化是由系統自動處理。不過還是有一些你必須考慮的限制。最重要的是,Flax 只處理有狀態的原語(Linen 變數和 RNG),而不處理任意的有狀態 Python 程式碼。最重要的是:你無法封閉有狀態的物件和 Module
物件,因為它們對 Flax 的內部(與一般來說的 JAX)是不可見的。
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
dense = nn.Dense(x.shape[-1])
fn = lambda x: dense(x) + 1
# simply calling inner works fine
# return self.inner(x, fn)
# but applying a transformation doesn't:
vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
return vmap_inner(self, x, fn)
def inner(self, x, fn):
for i in range(3):
x = fn(x)
return x
在此,inner
需要一個封閉模組實例的函式。這個範例中,這樣做沒有問題,因為我們並未以提昇轉換轉換內部方法。大多數方法都不會轉換,但知道如何使模組方法變為可轉換的會很有用。
無法轉換的主要障礙是 JAX 辨識不出的類型。JAX 只理解 Pytree 引數;也就是任意的嵌套 Python 容器(dict、清單、tuple)包含(Jax)numpy ndarrays 以及 Python 數字/布林值。Flax 允許使用 flax.struct API 定義與 Pytree 相容的資料類別。
函式封閉是用 JAX 陣列或 Linen 模組意外隱藏轉換的最常見方式。不過,如果你想傳遞與 JAX 和 Linen 轉換相容的封閉,有一個簡單的解決方法
class Partial(flax.struct.PyTreeNode):
fn: Callable = flax.struct.field(pytree_node=False)
args: Iterable[Any]
def __call__(self, *args, **kwargs):
return self.fn(*(tuple(self.args) + args), **kwargs)
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
dense = nn.Dense(x.shape[-1])
fn = lambda mdl, x: mdl(x) + 1
vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
return vmap_inner(self, x, Partial(fn, [dense]))
def inner(self, x, fn):
for i in range(3):
x = fn(x)
return x
在此,封閉使用 Flax 資料類別實作。函式本身註解為 flax.struct.field(pytree_node=False)
,以指出它不包含 JAX 陣列或 Linen 模組。相反地,部分套用的 args
被當成 pytree 容器處理。我們重新編寫封閉來使用 Partial。現在可以使用提昇轉換轉換內部方法了。
後續工作#
未繫結模組設定#
建立後要初始化欄位時,目前的模組抽象特別嚴格。在目前的模組 API 中,setup
方法是初始化模組實體欄位的地方。因為 setup
僅在已繫結的模組上呼叫,所以完整的模組 API 可於 setup
內使用,包含宣告變數。但是,我們通常不需要任何有狀態 API 來初始化欄位。事實上,我們最常單純宣告一個子模組。更重要的是,通常會檢查子模組以進行除錯,或是部分執行模型。例如
class AutoEncoder(nn.Module):
def setup(self):
self.encoder = Encoder(...)
self.decoder = Decoder(...)
想像一下,我們希望使用 auto_encoder.decoder.apply(decoder_variables, x) 來呼叫解碼器。透過目前的設定 API,這個做法無法運作,因為我們必須先繫結變數,才會呼叫設定,且解碼器屬性已定義。我們當然可以使用與設定中相同的屬性,來手動建構解碼器模組,但這在許多情況下並非理想作法。
有兩種可能的解決方案,可讓此使用案例更符合人體工學。第一,設定可在建立後立即執行,並在繫結前生效。這表示你仍然可以建立子模組,但你無法再定義或處理變數。因此,這將會是一個重大變更,而且需要新的 API 才能延遲定義變數
或者,可以引入另一個特殊方法,在模組建立後立即執行,且在它繫結之前就生效。這種情況下,setup
方法將保留其原始語義。