處理 Flax 模組參數

處理 Flax 模組參數#

緒論#

在 Flax Linen 中,我們可以使用資料類別屬性或方法引數來定義 Module 引數(通常為 __call__)。這種區別通常很明確

  • 完全固定的屬性,例如核初始化選項或輸出特徵數量,是超參數,應定義為資料類別屬性。通常具有不同超參數的兩個 Module 實例無法以有意義的方式共享。

  • 動態屬性(例如輸入資料和頂層「模式切換」,如 train=True/False)應作為引數傳遞給 __call__ 或其他方法。

不過,有些情況則較不分明。以 Dropout 模組為例,我們有多個明確的超參數

  1. 中斷率

  2. 產生中斷遮罩的軸

以及一些明確的呼叫時間引數

  1. 應使用中斷遮罩的輸入

  2. 用於抽樣隨機遮罩的(可選)rng

不過,有一個屬性存在歧義,即 Dropout 模組中的 deterministic 屬性。

如果 deterministicTrue,則不會抽樣中斷遮罩。這通常用於模型評估。但是,如果我們將 eval=Truetrain=False 傳遞給頂層 Module。則必須在所有位置套用 deterministic 引數,且布林引數需要傳遞給所有可能會使用 Dropout 的層。如果 deterministic 是資料類別屬性,我們可能會執行下列動作

from functools import partial
from flax import linen as nn

class ResidualModel(nn.Module):
  drop_rate: float

  @nn.compact
  def __call__(self, x, *, train):
    dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
    for i in range(10):
      x += ResidualBlock(dropout=dropout, ...)(x)

在此將 determinstic 傳遞給建構式很有道理,因為這麼一來,我們可以將中斷範本傳遞給子模組。現在,子模組不再需要處理 train 與 eval 模式,而只需使用 dropout 引數即可。請注意,由於中斷層只能在子模組中建構,因此我們只能將 deterministic 部分套用於建構式,而無法套用於 __call__

但是,如果 deterministic 是資料類別屬性,當使用設定模式時,我們會遇到問題。我們希望像這樣撰寫我們的模組程式碼

class SomeModule(nn.Module):
  drop_rate: float

  def setup(self):
    self.dropout = nn.Dropout(rate=self.drop_rate)

  @nn.compact
  def __call__(self, x, *, train):
    # ...
    x = self.dropout(x, deterministic=not train)
    # ...

但是,如以上定義,deterministic 將會是一個屬性,所以這無法運作。在此情況下,傳遞 deterministic__call__ 中是有意義的,因為這取決於 train 參數。

解決方案#

我們可以支援之前所描述的兩種使用案例,透過允許特定屬性做為 dataclass 屬性或方法參數傳遞(但不能同時傳遞)。這可以用以下方式實作

class MyDropout(nn.Module):
  drop_rate: float
  deterministic: Optional[bool] = None

  @nn.compact
  def __call__(self, x, deterministic=None):
    deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
    # ...

在這個範例中,nn.merge_param 會確保 self.deterministicdeterministic 設定其中一個,但不能同時設定。如果兩個值均為 None 或均不為 None,則會回傳錯誤。這可以避免令人混淆的行為,也就是程式碼中的兩個不同部分設定相同的參數,且一個被另一個覆寫。它也可以避免預設值,這可能會造成訓練程序訓練步驟或評估步驟在預設情況下中斷。

函式核心中#

函式核心中定義函式,而非類別。因此,超參數和呼叫時間參數之間沒有明確的區別。預先決定超參數的唯一方法是使用 partial。好處是沒有方法參數也可以成為屬性的模糊案例。