使用篩選器#
注意:此頁面與新的 Flax NNX API 相關。
篩選器在 Flax NNX 中被廣泛用作在諸如 nnx.split
、nnx.state
和許多 Flax NNX 轉換等 API 中建立 State
群組的方法。 例如
from flax import nnx
class Foo(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = nnx.BatchStat(True)
foo = Foo()
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
batch_stats = State({
'b': VariableState(
type=BatchStat,
value=True
)
})
這裡,nnx.Param
和 nnx.BatchStat
被用作篩選器,將模型分成兩組:一組包含參數,另一組包含批次統計資訊。然而,這引發了以下問題
什麼是篩選器?
為什麼像
Param
或BatchStat
這樣的類型是篩選器?如何群組/篩選
State
?
篩選器協定#
一般來說,篩選器是以下形式的謂詞函數
(path: tuple[Key, ...], value: Any) -> bool
其中 Key
是一種可雜湊和可比較的類型,path
是一個由 Key
組成的元組,表示巢狀結構中值的路徑,而 value
是路徑上的值。如果該值應包含在群組中,則該函數返回 True
,否則返回 False
。
類型顯然不是這種形式的函數,因此它們被視為篩選器的原因是,正如我們接下來將看到的,類型和一些其他文字會轉換為謂詞。例如,Param
大致轉換為類似這樣的謂詞
def is_param(path, value) -> bool:
return isinstance(value, nnx.Param) or (
hasattr(value, 'type') and issubclass(value.type, nnx.Param)
)
print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
此函數會匹配任何屬於 Param
實例的值,或具有 type
屬性並且該屬性是 Param
子類的值。在內部,Flax NNX 使用 OfType
,它為給定類型定義了這種形式的可呼叫物件
is_param = nnx.OfType(nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
篩選器 DSL#
為了避免使用者必須建立這些函數,Flax NNX 公開了一個小的 DSL,以 nnx.filterlib.Filter
類型形式化,讓使用者可以傳遞類型、布林值、省略符號、元組/列表等,並在內部將它們轉換為適當的謂詞。
以下是 Flax NNX 中包含的所有可呼叫篩選器及其 DSL 文字(如果可用)的清單
文字 |
可呼叫 |
描述 |
---|---|---|
|
|
匹配所有值 |
|
|
不匹配任何值 |
|
|
匹配屬於 |
|
匹配具有包含給定 |
|
|
|
匹配具有與 |
|
|
匹配符合任何內部 |
|
匹配符合所有內部 |
|
|
匹配不符合內部 |
讓我們看一個 nnx.vmap
範例,來看看 DSL 的實際應用。假設我們想要在第 0 軸上向量化所有參數和 dropout
Rng(Keys|Counts),並廣播其餘部分。為此,我們可以使用以下篩選器來定義一個 nnx.StateAxes
物件,我們可以將其傳遞給 nnx.vmap
的 in_axes
,以指定應如何向量化 model
的各種子狀態
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})
@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
...
在這裡,(nnx.Param, 'dropout')
展開為 Any(OfType(nnx.Param), WithTag('dropout'))
,而 ...
展開為 Everything()
。
如果您希望手動將文字轉換為謂詞,則可以使用 nnx.filterlib.to_predicate
is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))
print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.variablelib.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))
群組化狀態#
在掌握篩選器的知識後,讓我們看看 nnx.split
是如何大致實作的。關鍵概念
使用
nnx.graph.flatten
取得節點的GraphDef
和State
表示。將所有篩選器轉換為謂詞。
使用
State.flat_state
取得狀態的平面表示。遍歷平面狀態中的所有
(path, value)
對,並根據謂詞對它們進行分組。使用
State.from_flat_state
將平面狀態轉換為巢狀State
。
from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]
def split(node, *filters):
graphdef, state = nnx.graph.flatten(node)
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
for path, value in state.flat_state():
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i][path] = value
break
else:
raise ValueError(f'No filter matched {path = } {value = }')
states: tuple[nnx.GraphState, ...] = tuple(
nnx.State.from_flat_path(flat_state) for flat_state in flat_states
)
return graphdef, *states
# lets test it...
foo = Foo()
graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
batch_stats = State({
'b': VariableState(
type=BatchStat,
value=True
)
})
一個非常重要的事情要注意的是,篩選是依賴順序的。第一個匹配值篩選器會保留該值,因此您應該將更具體的篩選器放在更通用的篩選器之前。例如,如果我們建立一個屬於 Param
子類的 SpecialParam
類型,以及一個包含兩種參數類型的 Bar
物件,如果我們嘗試在 SpecialParam
之前分割 Param
,那麼所有值都將被放置在 Param
群組中,而 SpecialParam
群組將為空,因為所有 SpecialParam
也是 Param
class SpecialParam(nnx.Param):
pass
class Bar(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = SpecialParam(0)
bar = Bar()
graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': VariableState(
type=Param,
value=0
),
'b': VariableState(
type=SpecialParam,
value=0
)
})
special_params = State({})
反轉順序將確保首先捕獲 SpecialParam
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
special_params = State({
'b': VariableState(
type=SpecialParam,
value=0
)
})