圖表#

flax.nnx.split(node, *filters)[原始碼]#

將圖形節點分割成 GraphDef 和一個或多個 State`s。 State 一個 ``Mapping`,從字串或整數到 Variables、陣列或巢狀 State。 GraphDef 包含重建 Module 圖形所需的所有靜態資訊,它類似於 JAX 的 PyTreeDefsplit()merge() 結合使用,以在圖形的有狀態和無狀態表示之間無縫切換。

使用範例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> jax.tree.map(jnp.shape, params)
State({
  'batch_norm': {
    'bias': VariableState(
      type=Param,
      value=(2,)
    ),
    'scale': VariableState(
      type=Param,
      value=(2,)
    )
  },
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
>>> jax.tree.map(jnp.shape, batch_stats)
State({
  'batch_norm': {
    'mean': VariableState(
      type=BatchStat,
      value=(2,)
    ),
    'var': VariableState(
      type=BatchStat,
      value=(2,)
    )
  }
})

split()merge() 主要用於直接與 JAX 轉換互動,請參閱 Functional API 以取得更多資訊。

參數
  • node – 要分割的圖形節點。

  • *filters – 一些可選的篩選器,將狀態分組為互斥的子狀態。

回傳值

GraphDef 和一個或多個 States,等於傳遞的篩選器數量。如果未傳遞任何篩選器,則回傳單個 State

flax.nnx.merge(graphdef, state, /, *states)[原始碼]#

flax.nnx.split() 相反。

nnx.merge 接收一個 flax.nnx.GraphDef 和一個或多個 flax.nnx.State,並建立一個與原始節點具有相同結構的新節點。

回顧:flax.nnx.split() 用於表示 flax.nnx.Module:1) 一個靜態 nnx.GraphDef,捕獲其 Pythonic 靜態資訊;以及 2) 一個或多個 flax.nnx.Variable nnx.State'(s),捕獲其 jax.Array's,以 JAX pytrees 的形式。

nnx.mergennx.split 結合使用,以在圖形的有狀態和無狀態表示之間無縫切換。

使用範例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> new_node = nnx.merge(graphdef, params, batch_stats)
>>> assert isinstance(new_node, Foo)
>>> assert isinstance(new_node.batch_norm, nnx.BatchNorm)
>>> assert isinstance(new_node.linear, nnx.Linear)

nnx.splitnnx.merge 主要用於直接與 JAX 轉換互動(請參閱 Functional API 以取得更多資訊)。

參數
回傳值

合併後的 flax.nnx.Module

flax.nnx.update(node, state, /, *states)[原始碼]#

使用新的狀態(s) 就地更新給定的圖形節點。

使用範例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

>>> def loss_fn(model, x, y):
...   return jnp.mean((y - model(x))**2)
>>> prev_loss = loss_fn(model, x, y)

>>> grads = nnx.grad(loss_fn)(model, x, y)
>>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads)
>>> nnx.update(model, new_state)
>>> assert loss_fn(model, x, y) < prev_loss
參數
  • node – 要更新的圖形節點。

  • state – 一個 State 物件。

  • *states – 額外的 State 物件。

flax.nnx.pop(node, *filters)[原始碼]#

從圖形節點中彈出一個或多個 Variable 類型。

使用範例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')

>>> intermediates = nnx.pop(model, nnx.Intermediate)
>>> assert intermediates['i'].value[0].shape == (1, 3)
>>> assert not hasattr(model, 'i')
參數
  • node – 圖形節點物件。

  • *filters – 一個或多個要篩選的 Variable 物件。

回傳值

彈出的 State,其中包含篩選的 Variable 物件。

flax.nnx.state(node, *filters)[原始碼]#

類似於 split(),但只會返回由篩選器指定的 State

使用範例

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batch_norm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
參數
  • node – 圖形節點物件。

  • *filters – 一個或多個要篩選的 Variable 物件。

回傳值

一個或多個 State 的映射。

flax.nnx.variables(node, *filters)[原始碼]#

類似於 state(),但返回目前的 Variable 物件,而不是新的 VariableState 實例。

範例

>>> from flax import nnx
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> params = nnx.variables(model, nnx.Param)
...
>>> assert params['kernel'] is model.kernel
>>> assert params['bias'] is model.bias
參數
  • node – 圖形節點物件。

  • *filters – 一個或多個要篩選的 Variable 物件。

回傳值

一個或多個包含 Variable 物件的 State 映射。

flax.nnx.graph()#
flax.nnx.graphdef(node, /)[原始碼]#

取得給定圖形節點的 GraphDef

使用範例

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, _ = nnx.split(model)
>>> assert graphdef == nnx.graphdef(model)
參數

node – 圖形節點物件。

回傳值

Module 物件的 GraphDef

flax.nnx.iter_graph(node, /)[原始碼]#

迭代給定圖形節點的所有巢狀節點和葉節點,包括目前的節點。

iter_graph 建立一個產生器,產生路徑和值對,其中路徑是一個字串或整數元組,表示從根節點到值的路徑。重複的節點只會拜訪一次。葉節點包含靜態值。

範例:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.din, self.dout = din, dout
...     self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...
>>> module = Linear(3, 4, rngs=nnx.Rngs(0))
>>> graph = [module, module]
...
>>> for path, value in nnx.iter_graph(graph):
...   print(path, type(value).__name__)
...
(0, 'b') Param
(0, 'din') int
(0, 'dout') int
(0, 'w') Param
(0,) Linear
() list
flax.nnx.clone(node)[原始碼]#

建立給定圖形節點的深層副本。

使用範例

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> cloned_model = nnx.clone(model)
>>> model.bias.value += 1
>>> assert (model.bias.value != cloned_model.bias.value).all()
參數

node – 圖形節點物件。

回傳值

Module 物件的深層副本。

flax.nnx.call(graphdef_state, /)[原始碼]#

呼叫由 (GraphDef, State) 對定義的底層圖形節點的方法。

call 採用 (GraphDef, State) 對,並建立一個 Proxy 物件,該物件可用於呼叫底層圖形節點上的方法。當方法被呼叫時,會傳回輸出,以及一個新的 (GraphDef, State) 對,表示圖形節點的更新狀態。call 等同於 merge() > method > split(),但在純 JAX 函式中更方便使用。

範例

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> linear = StatefulLinear(3, 2, nnx.Rngs(0))
>>> linear_state = nnx.split(linear)
...
>>> @jax.jit
... def forward(x, linear_state):
...   y, linear_state = nnx.call(linear_state)(x)
...   return y, linear_state
...
>>> x = jnp.ones((1, 3))
>>> y, linear_state = forward(x, linear_state)
>>> y, linear_state = forward(x, linear_state)
...
>>> linear = nnx.merge(*linear_state)
>>> linear.count.value
Array(2, dtype=uint32)

call 傳回的 Proxy 物件支援索引和屬性存取,以存取巢狀方法。在下面的範例中,索引 increment 方法用於呼叫 nodes 字典的 b 鍵的 StatefulLinear 模組的 increment 方法。

>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> rngs = nnx.Rngs(0)
>>> nodes = dict(
...   a=StatefulLinear(3, 2, rngs),
...   b=StatefulLinear(2, 1, rngs),
... )
...
>>> node_state = nnx.split(nodes)
>>> # use attribute access
>>> _, node_state = nnx.call(node_state)['b'].increment()
...
>>> nodes = nnx.merge(*node_state)
>>> nodes['a'].count.value
Array(0, dtype=uint32)
>>> nodes['b'].count.value
Array(1, dtype=uint32)
class flax.nnx.GraphDef[原始碼]#

一個類別,代表 Flax Module 的所有靜態、無狀態和 Pythonic 部分。可以透過在 Module 上呼叫 split()graphdef() 來產生 GraphDef

class flax.nnx.UpdateContext(tag, ref_index, index_ref)[原始碼]#

用於處理複雜狀態更新的上下文管理器。

merge(graphdef, state, *states)[原始碼]#
split(node, *filters)[原始碼]#

將圖形節點分割成 GraphDef 和一個或多個 State`s。 State 一個 ``Mapping`,從字串或整數到 Variables、陣列或巢狀 State。 GraphDef 包含重建 Module 圖形所需的所有靜態資訊,它類似於 JAX 的 PyTreeDefsplit()merge() 結合使用,以在圖形的有狀態和無狀態表示之間無縫切換。

使用範例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> jax.tree.map(jnp.shape, params)
State({
  'batch_norm': {
    'bias': VariableState(
      type=Param,
      value=(2,)
    ),
    'scale': VariableState(
      type=Param,
      value=(2,)
    )
  },
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
>>> jax.tree.map(jnp.shape, batch_stats)
State({
  'batch_norm': {
    'mean': VariableState(
      type=BatchStat,
      value=(2,)
    ),
    'var': VariableState(
      type=BatchStat,
      value=(2,)
    )
  }
})
參數
  • node – 要分割的圖形節點。

  • *filters – 一些可選的篩選器,將狀態分組為互斥的子狀態。

回傳值

GraphDef 和一個或多個 State,數量等於傳遞的篩選器數量。如果沒有傳遞篩選器,則會傳回單個 State

flax.nnx.update_context(tag)[原始碼]#

建立一個 UpdateContext 上下文管理器,可用於處理比 nnx.update 可以處理的更複雜的狀態更新,包括靜態屬性和圖形結構的更新。

UpdateContext 公開一個 splitmerge API,其簽名與 nnx.split / nnx.merge 相同,但會執行一些簿記,以具有必要的資訊,以便根據轉換內所做的變更完美地更新輸入物件。UpdateContext 必須總共呼叫 split 和 merge 4 次,第一次和最後一次呼叫發生在轉換之外,第二次和第三次呼叫發生在轉換內部,如下圖所示

                      idxmap
(2) merge ─────────────────────────────► split (3)
      ▲                                    │
      │               inside               │
      │. . . . . . . . . . . . . . . . . . │ index_mapping
      │               outside              │
      │                                    ▼
(1) split──────────────────────────────► merge (4)
                      refmap

第一次呼叫 split (1) 會建立一個 refmap,用來追蹤外部參照;而第一次呼叫 merge (2) 會建立一個 idxmap,用來追蹤內部參照。第二次呼叫 split (3) 會結合 refmap 和 idxmap 來產生 index_mapping,它會指出外部參照如何對應到內部參照。最後,最後一次呼叫 merge (4) 會使用 index_mapping 和 refmap 來重建轉換的輸出,同時重複使用/更新內部參照。為了避免記憶體洩漏,idxmap 會在 (3) 之後清除,而 refmap 會在 (4) 之後清除,兩者都會在上下文管理器結束後清除。

以下是一個簡單的範例,展示如何使用 update_context

>>> from flax import nnx
...
>>> m1 = nnx.Dict({})
>>> with nnx.update_context('example') as ctx:
...   graphdef, state = ctx.split(m1)
...   @jax.jit
...   def f(graphdef, state):
...     m2 = ctx.merge(graphdef, state)
...     m2.a = 1
...     m2.ref = m2  # create a reference cycle
...     return ctx.split(m2)
...   graphdef_out, state_out = f(graphdef, state)
...   m3 = ctx.merge(graphdef_out, state_out)
...
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

請注意,update_context 接受一個 tag 引數,主要用作安全機制,以減少在使用 current_update_context() 存取目前作用中的上下文時,意外使用錯誤 UpdateContext 的風險。current_update_context 可以用作存取目前作用中的上下文的方法,而無需將其作為捕獲傳遞。

>>> from flax import nnx
...
>>> m1 = nnx.Dict({})
>>> @jax.jit
... def f(graphdef, state):
...   ctx = nnx.current_update_context('example')
...   m2 = ctx.merge(graphdef, state)
...   m2.a = 1     # insert static attribute
...   m2.ref = m2  # create a reference cycle
...   return ctx.split(m2)
...
>>> @nnx.update_context('example')
... def g(m1):
...   ctx = nnx.current_update_context('example')
...   graphdef, state = ctx.split(m1)
...   graphdef_out, state_out = f(graphdef, state)
...   return ctx.merge(graphdef_out, state_out)
...
>>> m3 = g(m1)
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

如以上程式碼所示,update_context 也可用作裝飾器,在函式執行期間建立/啟用 UpdateContext 上下文。可以使用 current_update_context() 存取上下文。

參數

tag – 用於識別上下文的字串標籤。

flax.nnx.current_update_context(tag)[原始碼]#

返回給定標籤的目前作用中的 UpdateContext