setup vs compact#

在 Flax 的模組系統當中(稱為 Linen),可以透過下列兩種方式來定義子模組和變數(參數或其他變數)

  1. 明確(使用 setup

    setup 方法中,將子模組或變數指派給 self.<attr>。然後,在類別中定義的任何「前向傳遞」方法中,使用指派給 self.<attr> 的子模組和變數。這類似於 PyTorch 中定義模組的方式。

  2. 共置(使用 nn.compact

    直接在用 nn.compact 加註的單一「前向傳遞」方法中撰寫網路的邏輯。這允許您在單一方法中定義整個模組,而且可以將子模組和變數「配置」在用到的位置旁邊。

這兩種方法都是完全有效的,行為相同,而且可以與所有 Flax 互通使用。.

以下是一個模組透過兩種方式定義的簡短範例,而且具有完全相同的功能。

class MLP(nn.Module):
  def setup(self):
    # Submodule names are derived by the attributes you assign to. In this
    # case, "dense1" and "dense2". This follows the logic in PyTorch.
    self.dense1 = nn.Dense(32)
    self.dense2 = nn.Dense(32)

  def __call__(self, x):
    x = self.dense1(x)
    x = nn.relu(x)
    x = self.dense2(x)
    return x
class MLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(32, name="dense1")(x)
    x = nn.relu(x)
    x = nn.Dense(32, name="dense2")(x)
    return x

因此,您會如何決定要使用哪種樣式?這可能是品味問題,但以下是優缺點

較喜歡使用 nn.compact 的理由:#

  1. 允許在需要用到的位置旁邊定義子模組、參數和其他變數:較少需要向上/向下捲動螢幕來檢視所有內容是如何定義的。

  2. 當條件或迴圈以條件式定義子模組、參數或變數時,可減少重複的程式碼。

  3. 程式碼通常更類似於數學符號:y = self.param('W', ...) @ x + self.param('b', ...) 看起來類似於 \(y=Wx+b\))

  4. 如果您使用形狀推論,亦即使用其形狀/值依賴於輸入形狀(在初始化時未知)的參數,那麼無法使用 setup

偏好使用 setup 的原因:#

  1. 更接近 PyTorch 慣例,因此在從 PyTorch 移植範例時更容易

  2. 有些人認為明確區分副模組的定義和變數與使用它們的地方更自然

  3. 允許定義多個「前向傳遞」方法(參閱 MultipleMethodsCompactError