flax.struct 套件

flax.struct 套件#

用於定義可與 jax 轉換一起使用的自訂類別的工具。

flax.struct.dataclass(clz=None, **kwargs)[原始碼]#

建立一個可以傳遞到函數式轉換的類別。

注意

繼承自 PyTreeNode,以避免在使用 PyType 時發生類型檢查問題。

諸如 jax.jitjax.grad 等 Jax 轉換需要不可變的物件,並且可以使用 jax.tree_util 方法進行映射。dataclass 裝飾器可以輕鬆定義可以安全地傳遞給 Jax 的自訂類別。例如

>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable

>>> @struct.dataclass
... class Model:
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by Jax transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)

...   def __apply__(self, *args):
...     return self.apply_fn(*args)

>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)

>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.

>>> # This class can now be used safely in Jax to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)

請注意,資料類別具有自動產生的 __init__,其中建構子的引數和所建立實例的屬性是一對一匹配的。這種對應關係使得這些物件成為有效的容器,可以與 JAX 轉換以及更廣泛的 jax.tree_util 程式庫一起使用。

有時需要一個「智慧建構子」,例如因為某些屬性可以(選擇性地)從其他屬性衍生而來。使用 Flax 資料類別執行此操作的方法是建立一個提供智慧建構子的靜態或類別方法。這樣,jax.tree_util 使用的簡單建構子就能夠保留。請看以下範例

>>> @struct.dataclass
... class DirectionAndScaleKernel:
...   direction: jax.Array
...   scale: jax.Array

...   @classmethod
...   def create(cls, kernel):
...     scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
...     direction = direction / scale
...     return cls(direction, scale)
參數
  • clz – 將由裝飾器轉換的類別。

  • **kwargs – 要傳遞給資料類別建構子的引數。

回傳值

新的類別。

class flax.struct.PyTreeNode(*args, **kwargs)[原始碼]#

應該像 JAX pytree 節點一樣運作的資料類別的基底類別。

請參閱 flax.struct.dataclass 以了解 jax.tree_util 行為。此基底類別另外避免在使用 PyType 時發生類型檢查錯誤。

範例

>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable

>>> class Model(struct.PyTreeNode):
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by Jax transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)

...   def __apply__(self, *args):
...     return self.apply_fn(*args)

>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)

>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.

>>> # This class can now be used safely in Jax to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)