flax.traverse_util 套件#
一個用於遍歷不可變資料結構的工具。
Traversal 可用於迭代和更新複雜資料結構。Traversal 會採用一個物件並回傳其部分內容。例如,Traversal 可以選擇一個物件的屬性
>>> from flax import traverse_util
>>> import dataclasses
>>> @dataclasses.dataclass
... class Foo:
... foo: int = 0
... bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]
可以使用組合來建構更複雜的 Traversal。通常可以從身分識別 Traversal 開始,然後使用方法鏈來建構預期的 Traversal
>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]
Traversal 也可以透過使用update
方法來進行變更
>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}
Traversal 從來不會改變原始資料。因此,更新基本上會回傳一個包含已提供更新的資料副本。
Traversal 物件#
Dict utils#
- flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)[source]#
Flatten a nested dictionary.
The nested keys are flattened to a tuple. See
unflatten_dict
on how to restore the nested dictionary structure.Example
>>> from flax.traverse_util import flatten_dict >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} >>> flat_xs = flatten_dict(xs) >>> flat_xs {('foo',): 1, ('bar', 'a'): 2}
Note that empty dictionaries are ignored and will not be restored by
unflatten_dict
.- 參數
xs – a nested dictionary
keep_empty_nodes – replaces empty dictionaries with
traverse_util.empty_node
.is_leaf – 一個選用函式,它會使用下一個巢狀字典和巢狀金鑰,並傳回 True,如果巢狀字典為葉節點(即,不應進一步壓平)。
sep – 如果有指定,傳回字典的金鑰將會是
sep
連接字串(如果None
,則金鑰將會是組)。
- 傳回
壓平後的字典。
- flax.traverse_util.unflatten_dict(xs, sep=None)[source]#
取消壓平一個字典。
請參閱
flatten_dict
Example
>>> flat_xs = { ... ('foo',): 1, ... ('bar', 'a'): 2, ... } >>> xs = unflatten_dict(flat_xs) >>> xs {'foo': 1, 'bar': {'a': 2}}
- 參數
xs – 壓平後的字典
sep – 分隔符號(和
flatten_dict()
搭配使用時相同)。
- 傳回
巢狀字典。
- flax.traverse_util.path_aware_map(f, nested_dict)[source]#
一個映射函式,它在作業巢狀字典結構時,會將路徑納入考量,以處理每個葉節點。
Example
>>> import jax.numpy as jnp >>> from flax import traverse_util >>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}} >>> f = lambda path, x: x + 5 if 'x' in path else -x >>> traverse_util.path_aware_map(f, params) {'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
- 參數
f – 一個接受
(path, value)
參數,並將它們對應到新值的呼叫函式。這裡的path
是字串的組。nested_dict – 巢狀字典結構。
- 傳回
擁有對應值的新巢狀字典結構。
模型參數橫跨#
- class flax.traverse_util.ModelParamTraversal(*args, **kwargs)[source]#
使用名稱篩選器選取模型參數。
這個橫跨作業於一個參數的巢狀字典,並根據
filter_fn
參數選取子集。請參閱
flax.optim.MultiOptimizer
,了解如何使用ModelParamTraversal
更新特定優化器中的參數樹子集的範例。