處理 Flax 模組參數#
緒論#
在 Flax Linen 中,我們可以使用資料類別屬性或方法引數來定義 Module
引數(通常為 __call__
)。這種區別通常很明確
完全固定的屬性,例如核初始化選項或輸出特徵數量,是超參數,應定義為資料類別屬性。通常具有不同超參數的兩個 Module 實例無法以有意義的方式共享。
動態屬性(例如輸入資料和頂層「模式切換」,如
train=True/False
)應作為引數傳遞給__call__
或其他方法。
不過,有些情況則較不分明。以 Dropout
模組為例,我們有多個明確的超參數
中斷率
產生中斷遮罩的軸
以及一些明確的呼叫時間引數
應使用中斷遮罩的輸入
用於抽樣隨機遮罩的(可選)rng
不過,有一個屬性存在歧義,即 Dropout 模組中的 deterministic
屬性。
如果 deterministic
為 True
,則不會抽樣中斷遮罩。這通常用於模型評估。但是,如果我們將 eval=True
或 train=False
傳遞給頂層 Module。則必須在所有位置套用 deterministic
引數,且布林引數需要傳遞給所有可能會使用 Dropout
的層。如果 deterministic
是資料類別屬性,我們可能會執行下列動作
from functools import partial
from flax import linen as nn
class ResidualModel(nn.Module):
drop_rate: float
@nn.compact
def __call__(self, x, *, train):
dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
for i in range(10):
x += ResidualBlock(dropout=dropout, ...)(x)
在此將 determinstic
傳遞給建構式很有道理,因為這麼一來,我們可以將中斷範本傳遞給子模組。現在,子模組不再需要處理 train 與 eval 模式,而只需使用 dropout
引數即可。請注意,由於中斷層只能在子模組中建構,因此我們只能將 deterministic
部分套用於建構式,而無法套用於 __call__
。
但是,如果 deterministic
是資料類別屬性,當使用設定模式時,我們會遇到問題。我們希望像這樣撰寫我們的模組程式碼
class SomeModule(nn.Module):
drop_rate: float
def setup(self):
self.dropout = nn.Dropout(rate=self.drop_rate)
@nn.compact
def __call__(self, x, *, train):
# ...
x = self.dropout(x, deterministic=not train)
# ...
但是,如以上定義,deterministic
將會是一個屬性,所以這無法運作。在此情況下,傳遞 deterministic
於 __call__
中是有意義的,因為這取決於 train
參數。
解決方案#
我們可以支援之前所描述的兩種使用案例,透過允許特定屬性做為 dataclass 屬性或方法參數傳遞(但不能同時傳遞)。這可以用以下方式實作
class MyDropout(nn.Module):
drop_rate: float
deterministic: Optional[bool] = None
@nn.compact
def __call__(self, x, deterministic=None):
deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
# ...
在這個範例中,nn.merge_param
會確保 self.deterministic
或 deterministic
設定其中一個,但不能同時設定。如果兩個值均為 None
或均不為 None
,則會回傳錯誤。這可以避免令人混淆的行為,也就是程式碼中的兩個不同部分設定相同的參數,且一個被另一個覆寫。它也可以避免預設值,這可能會造成訓練程序訓練步驟或評估步驟在預設情況下中斷。
函式核心中#
函式核心中定義函式,而非類別。因此,超參數和呼叫時間參數之間沒有明確的區別。預先決定超參數的唯一方法是使用 partial
。好處是沒有方法參數也可以成為屬性的模糊案例。