使用篩選器#

注意:此頁面與新的 Flax NNX API 相關。

篩選器在 Flax NNX 中被廣泛用作在諸如 nnx.splitnnx.state 和許多 Flax NNX 轉換等 API 中建立 State 群組的方法。 例如

from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = nnx.BatchStat(True)

foo = Foo()

graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})

這裡,nnx.Paramnnx.BatchStat 被用作篩選器,將模型分成兩組:一組包含參數,另一組包含批次統計資訊。然而,這引發了以下問題

  • 什麼是篩選器?

  • 為什麼像 ParamBatchStat 這樣的類型是篩選器?

  • 如何群組/篩選 State

篩選器協定#

一般來說,篩選器是以下形式的謂詞函數


(path: tuple[Key, ...], value: Any) -> bool

其中 Key 是一種可雜湊和可比較的類型,path 是一個由 Key 組成的元組,表示巢狀結構中值的路徑,而 value 是路徑上的值。如果該值應包含在群組中,則該函數返回 True,否則返回 False

類型顯然不是這種形式的函數,因此它們被視為篩選器的原因是,正如我們接下來將看到的,類型和一些其他文字會轉換為謂詞。例如,Param 大致轉換為類似這樣的謂詞

def is_param(path, value) -> bool:
  return isinstance(value, nnx.Param) or (
    hasattr(value, 'type') and issubclass(value.type, nnx.Param)
  )

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True

此函數會匹配任何屬於 Param 實例的值,或具有 type 屬性並且該屬性是 Param 子類的值。在內部,Flax NNX 使用 OfType,它為給定類型定義了這種形式的可呼叫物件

is_param = nnx.OfType(nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True

篩選器 DSL#

為了避免使用者必須建立這些函數,Flax NNX 公開了一個小的 DSL,以 nnx.filterlib.Filter 類型形式化,讓使用者可以傳遞類型、布林值、省略符號、元組/列表等,並在內部將它們轉換為適當的謂詞。

以下是 Flax NNX 中包含的所有可呼叫篩選器及其 DSL 文字(如果可用)的清單

文字

可呼叫

描述

...True

Everything()

匹配所有值

NoneFalse

Nothing()

不匹配任何值

類型

OfType(type)

匹配屬於 type 實例的值,或具有 type 屬性且該屬性是 type 實例的值

PathContains(key)

匹配具有包含給定 key 的關聯 path 的值

'{filter}' str

WithTag('{filter}')

匹配具有與 '{filter}' 相等的字串 tag 屬性的值。由 RngKeyRngCount 使用。

(*filters) tuple[*filters] list

Any(*filters)

匹配符合任何內部 filters 的值

All(*filters)

匹配符合所有內部 filters 的值

Not(filter)

匹配不符合內部 filter 的值

讓我們看一個 nnx.vmap 範例,來看看 DSL 的實際應用。假設我們想要在第 0 軸上向量化所有參數和 dropout Rng(Keys|Counts),並廣播其餘部分。為此,我們可以使用以下篩選器來定義一個 nnx.StateAxes 物件,我們可以將其傳遞給 nnx.vmapin_axes,以指定應如何向量化 model 的各種子狀態

state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})

@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
  ...

在這裡,(nnx.Param, 'dropout') 展開為 Any(OfType(nnx.Param), WithTag('dropout')),而 ... 展開為 Everything()

如果您希望手動將文字轉換為謂詞,則可以使用 nnx.filterlib.to_predicate

is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))

print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.variablelib.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))

群組化狀態#

在掌握篩選器的知識後,讓我們看看 nnx.split 是如何大致實作的。關鍵概念

  • 使用 nnx.graph.flatten 取得節點的 GraphDefState 表示。

  • 將所有篩選器轉換為謂詞。

  • 使用 State.flat_state 取得狀態的平面表示。

  • 遍歷平面狀態中的所有 (path, value) 對,並根據謂詞對它們進行分組。

  • 使用 State.from_flat_state 將平面狀態轉換為巢狀 State

from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  graphdef, state = nnx.graph.flatten(node)
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]

  for path, value in state.flat_state():
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        flat_states[i][path] = value
        break
    else:
      raise ValueError(f'No filter matched {path = } {value = }')

  states: tuple[nnx.GraphState, ...] = tuple(
    nnx.State.from_flat_path(flat_state) for flat_state in flat_states
  )
  return graphdef, *states

# lets test it...
foo = Foo()

graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})

一個非常重要的事情要注意的是,篩選是依賴順序的。第一個匹配值篩選器會保留該值,因此您應該將更具體的篩選器放在更通用的篩選器之前。例如,如果我們建立一個屬於 Param 子類的 SpecialParam 類型,以及一個包含兩種參數類型的 Bar 物件,如果我們嘗試在 SpecialParam 之前分割 Param,那麼所有值都將被放置在 Param 群組中,而 SpecialParam 群組將為空,因為所有 SpecialParam 也是 Param

class SpecialParam(nnx.Param):
  pass

class Bar(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = SpecialParam(0)

bar = Bar()

graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  ),
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
special_params = State({})

反轉順序將確保首先捕獲 SpecialParam

graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
special_params = State({
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})