flax.traverse_util 套件#

一個用於遍歷不可變資料結構的工具。

Traversal 可用於迭代和更新複雜資料結構。Traversal 會採用一個物件並回傳其部分內容。例如,Traversal 可以選擇一個物件的屬性

>>> from flax import traverse_util
>>> import dataclasses

>>> @dataclasses.dataclass
... class Foo:
...   foo: int = 0
...   bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]

可以使用組合來建構更複雜的 Traversal。通常可以從身分識別 Traversal 開始,然後使用方法鏈來建構預期的 Traversal

>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]

Traversal 也可以透過使用update 方法來進行變更

>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}

Traversal 從來不會改變原始資料。因此,更新基本上會回傳一個包含已提供更新的資料副本。

Traversal 物件#

類別 flax.traverse_util.Traversal(*args, **kwargs)[原始程式碼]#

所有 Traversal 的基底類別。

編寫(other)[原始程式碼]#

組合兩個 Traversal。

每個()[原始程式碼]#

遍歷選取容器中的每個項目。

篩選(fn)[原始程式碼]#

過濾選取的值。

摘要 反覆處理(輸入)[來源]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

合併(*反覆處理)[來源]#

組合任意數量的反覆處理並合併結果。

設定(, 輸入)[來源]#

覆寫 Traversal 選擇的值。

參數
  • – 包含新值的清單。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

樹狀圖()[來源]#

反覆處理 pytree 中的每個項目。

摘要 更新(fn, 輸入)[來源]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

類別 flax.traverse_util.TraverseId(*args, **kwargs)[原始程式碼]#

身分識別的 Traversal。

遍歷(輸入)[原始程式碼]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

更新(fn, 輸入)[原始程式碼]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

類別 flax.traverse_util.TraverseMerge(*args, **kwargs)[原始程式碼]#

合併一組 traversal 的選取結果。

遍歷(輸入)[原始程式碼]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

更新(fn, 輸入)[原始程式碼]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

類別 flax.traverse_util.TraverseCompose(*args, **kwargs)[原始碼]#

組合兩個 Traversal。

反覆處理(輸入)[原始碼]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

更新(函式, 輸入)[原始碼]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

類別 flax.traverse_util.TraverseFilter(*args, **kwargs)[原始碼]#

根據謂詞篩選選取值。

反覆處理(輸入)[原始碼]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

更新(函式, 輸入)[原始碼]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

類別 flax.traverse_util.TraverseAttr(*args, **kwargs)[原始碼]#

反覆處理物件屬性。

iterate(inputs)[來源]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

update(fn, inputs)[來源]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

類別 flax.traverse_util.TraverseItem(*args, **kwargs)[來源]#

遍歷物件的項目。

iterate(inputs)[來源]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

update(fn, inputs)[來源]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

類別 flax.traverse_util.TraverseEach(*args, **kwargs)[來源]#

遍歷儲存器中的每一個項目。

iterate(inputs)[來源]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

update(fn, inputs)[source]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

class flax.traverse_util.TraverseTree(*args, **kwargs)[source]#

Traverse every item in a pytree.

iterate(inputs)[source]#

反覆處理由 Traversal 選擇的值。

參數

輸入 – 應反覆處理的物件。

傳回

反覆處理過的值的迭代器。

update(fn, inputs)[source]#

更新焦點項目。

參數
  • fn – 將每個反覆處理項目對應至其更新值之回呼函式。

  • 輸入 – 應反覆處理的物件。

傳回

具有更新值的清單。

Dict utils#

flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)[source]#

Flatten a nested dictionary.

The nested keys are flattened to a tuple. See unflatten_dict on how to restore the nested dictionary structure.

Example

>>> from flax.traverse_util import flatten_dict

>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_dict(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}

Note that empty dictionaries are ignored and will not be restored by unflatten_dict.

參數
  • xs – a nested dictionary

  • keep_empty_nodes – replaces empty dictionaries with traverse_util.empty_node.

  • is_leaf – 一個選用函式,它會使用下一個巢狀字典和巢狀金鑰,並傳回 True,如果巢狀字典為葉節點(即,不應進一步壓平)。

  • sep – 如果有指定,傳回字典的金鑰將會是 sep 連接字串(如果 None,則金鑰將會是組)。

傳回

壓平後的字典。

flax.traverse_util.unflatten_dict(xs, sep=None)[source]#

取消壓平一個字典。

請參閱 flatten_dict

Example

>>> flat_xs = {
...   ('foo',): 1,
...   ('bar', 'a'): 2,
... }
>>> xs = unflatten_dict(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}
參數
  • xs – 壓平後的字典

  • sep – 分隔符號(和 flatten_dict() 搭配使用時相同)。

傳回

巢狀字典。

flax.traverse_util.path_aware_map(f, nested_dict)[source]#

一個映射函式,它在作業巢狀字典結構時,會將路徑納入考量,以處理每個葉節點。

Example

>>> import jax.numpy as jnp
>>> from flax import traverse_util

>>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
>>> f = lambda path, x: x + 5 if 'x' in path else -x
>>> traverse_util.path_aware_map(f, params)
{'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
參數
  • f – 一個接受 (path, value) 參數,並將它們對應到新值的呼叫函式。這裡的 path 是字串的組。

  • nested_dict – 巢狀字典結構。

傳回

擁有對應值的新巢狀字典結構。

模型參數橫跨#

class flax.traverse_util.ModelParamTraversal(*args, **kwargs)[source]#

使用名稱篩選器選取模型參數。

這個橫跨作業於一個參數的巢狀字典,並根據 filter_fn 參數選取子集。

請參閱 flax.optim.MultiOptimizer,了解如何使用 ModelParamTraversal 更新特定優化器中的參數樹子集的範例。

__init__(filter_fn)[來源]#

建立新的 ModelParamTraversal。

參數

filter_fn – 一個接受參數完整名稱和其值的函式,回傳應否選取此參數。參數名稱由模組層級和參數名稱決定 (例如:‘/module/sub_module/parameter_name’)。