圖表#
- flax.nnx.split(node, *filters)[原始碼]#
將圖形節點分割成
GraphDef
和一個或多個State`s。 State 是 一個 ``Mapping`
,從字串或整數到Variables
、陣列或巢狀 State。 GraphDef 包含重建Module
圖形所需的所有靜態資訊,它類似於 JAX 的PyTreeDef
。split()
與merge()
結合使用,以在圖形的有狀態和無狀態表示之間無縫切換。使用範例
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { 'bias': VariableState( type=Param, value=(2,) ), 'scale': VariableState( type=Param, value=(2,) ) }, 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { 'mean': VariableState( type=BatchStat, value=(2,) ), 'var': VariableState( type=BatchStat, value=(2,) ) } })
split()
和merge()
主要用於直接與 JAX 轉換互動,請參閱 Functional API 以取得更多資訊。- 參數
node – 要分割的圖形節點。
*filters – 一些可選的篩選器,將狀態分組為互斥的子狀態。
- 回傳值
GraphDef
和一個或多個States
,等於傳遞的篩選器數量。如果未傳遞任何篩選器,則回傳單個State
。
- flax.nnx.merge(graphdef, state, /, *states)[原始碼]#
與
flax.nnx.split()
相反。nnx.merge
接收一個flax.nnx.GraphDef
和一個或多個flax.nnx.State
,並建立一個與原始節點具有相同結構的新節點。回顧:
flax.nnx.split()
用於表示flax.nnx.Module
:1) 一個靜態nnx.GraphDef
,捕獲其 Pythonic 靜態資訊;以及 2) 一個或多個flax.nnx.Variable
nnx.State
'(s),捕獲其jax.Array
's,以 JAX pytrees 的形式。nnx.merge
與nnx.split
結合使用,以在圖形的有狀態和無狀態表示之間無縫切換。使用範例
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> new_node = nnx.merge(graphdef, params, batch_stats) >>> assert isinstance(new_node, Foo) >>> assert isinstance(new_node.batch_norm, nnx.BatchNorm) >>> assert isinstance(new_node.linear, nnx.Linear)
nnx.split
和nnx.merge
主要用於直接與 JAX 轉換互動(請參閱 Functional API 以取得更多資訊)。- 參數
graphdef – 一個
flax.nnx.GraphDef
物件。state – 一個
flax.nnx.State
物件。*states – 額外的
flax.nnx.State
物件。
- 回傳值
合併後的
flax.nnx.Module
。
- flax.nnx.update(node, state, /, *states)[原始碼]#
使用新的狀態(s) 就地更新給定的圖形節點。
使用範例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> def loss_fn(model, x, y): ... return jnp.mean((y - model(x))**2) >>> prev_loss = loss_fn(model, x, y) >>> grads = nnx.grad(loss_fn)(model, x, y) >>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads) >>> nnx.update(model, new_state) >>> assert loss_fn(model, x, y) < prev_loss
- flax.nnx.pop(node, *filters)[原始碼]#
從圖形節點中彈出一個或多個
Variable
類型。使用範例
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'i', x) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> assert not hasattr(model, 'i') >>> y = model(x) >>> assert hasattr(model, 'i') >>> intermediates = nnx.pop(model, nnx.Intermediate) >>> assert intermediates['i'].value[0].shape == (1, 3) >>> assert not hasattr(model, 'i')
- flax.nnx.state(node, *filters)[原始碼]#
類似於
split()
,但只會返回由篩選器指定的State
。使用範例
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batch_norm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> # get the learnable parameters from the batch norm and linear layer >>> params = nnx.state(model, nnx.Param) >>> # get the batch statistics from the batch norm layer >>> batch_stats = nnx.state(model, nnx.BatchStat) >>> # get them separately >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> # get them together >>> state = nnx.state(model)
- flax.nnx.variables(node, *filters)[原始碼]#
類似於
state()
,但返回目前的Variable
物件,而不是新的VariableState
實例。範例
>>> from flax import nnx ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> params = nnx.variables(model, nnx.Param) ... >>> assert params['kernel'] is model.kernel >>> assert params['bias'] is model.bias
- flax.nnx.graph()#
- flax.nnx.graphdef(node, /)[原始碼]#
取得給定圖形節點的
GraphDef
。使用範例
>>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> graphdef, _ = nnx.split(model) >>> assert graphdef == nnx.graphdef(model)
- flax.nnx.iter_graph(node, /)[原始碼]#
迭代給定圖形節點的所有巢狀節點和葉節點,包括目前的節點。
iter_graph
建立一個產生器,產生路徑和值對,其中路徑是一個字串或整數元組,表示從根節點到值的路徑。重複的節點只會拜訪一次。葉節點包含靜態值。- 範例:
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Linear(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.din, self.dout = din, dout ... self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... >>> module = Linear(3, 4, rngs=nnx.Rngs(0)) >>> graph = [module, module] ... >>> for path, value in nnx.iter_graph(graph): ... print(path, type(value).__name__) ... (0, 'b') Param (0, 'din') int (0, 'dout') int (0, 'w') Param (0,) Linear () list
- flax.nnx.clone(node)[原始碼]#
建立給定圖形節點的深層副本。
使用範例
>>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> cloned_model = nnx.clone(model) >>> model.bias.value += 1 >>> assert (model.bias.value != cloned_model.bias.value).all()
- 參數
node – 圖形節點物件。
- 回傳值
Module
物件的深層副本。
- flax.nnx.call(graphdef_state, /)[原始碼]#
呼叫由 (GraphDef, State) 對定義的底層圖形節點的方法。
call
採用(GraphDef, State)
對,並建立一個 Proxy 物件,該物件可用於呼叫底層圖形節點上的方法。當方法被呼叫時,會傳回輸出,以及一個新的 (GraphDef, State) 對,表示圖形節點的更新狀態。call
等同於merge()
>method
>split()
,但在純 JAX 函式中更方便使用。範例
>>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> linear = StatefulLinear(3, 2, nnx.Rngs(0)) >>> linear_state = nnx.split(linear) ... >>> @jax.jit ... def forward(x, linear_state): ... y, linear_state = nnx.call(linear_state)(x) ... return y, linear_state ... >>> x = jnp.ones((1, 3)) >>> y, linear_state = forward(x, linear_state) >>> y, linear_state = forward(x, linear_state) ... >>> linear = nnx.merge(*linear_state) >>> linear.count.value Array(2, dtype=uint32)
由
call
傳回的 Proxy 物件支援索引和屬性存取,以存取巢狀方法。在下面的範例中,索引increment
方法用於呼叫nodes
字典的b
鍵的StatefulLinear
模組的increment
方法。>>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> rngs = nnx.Rngs(0) >>> nodes = dict( ... a=StatefulLinear(3, 2, rngs), ... b=StatefulLinear(2, 1, rngs), ... ) ... >>> node_state = nnx.split(nodes) >>> # use attribute access >>> _, node_state = nnx.call(node_state)['b'].increment() ... >>> nodes = nnx.merge(*node_state) >>> nodes['a'].count.value Array(0, dtype=uint32) >>> nodes['b'].count.value Array(1, dtype=uint32)
- class flax.nnx.GraphDef[原始碼]#
一個類別,代表 Flax
Module
的所有靜態、無狀態和 Pythonic 部分。可以透過在Module
上呼叫split()
或graphdef()
來產生GraphDef
。
- class flax.nnx.UpdateContext(tag, ref_index, index_ref)[原始碼]#
用於處理複雜狀態更新的上下文管理器。
- split(node, *filters)[原始碼]#
將圖形節點分割成
GraphDef
和一個或多個State`s。 State 是 一個 ``Mapping`
,從字串或整數到Variables
、陣列或巢狀 State。 GraphDef 包含重建Module
圖形所需的所有靜態資訊,它類似於 JAX 的PyTreeDef
。split()
與merge()
結合使用,以在圖形的有狀態和無狀態表示之間無縫切換。使用範例
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { 'bias': VariableState( type=Param, value=(2,) ), 'scale': VariableState( type=Param, value=(2,) ) }, 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { 'mean': VariableState( type=BatchStat, value=(2,) ), 'var': VariableState( type=BatchStat, value=(2,) ) } })
- flax.nnx.update_context(tag)[原始碼]#
建立一個
UpdateContext
上下文管理器,可用於處理比nnx.update
可以處理的更複雜的狀態更新,包括靜態屬性和圖形結構的更新。UpdateContext 公開一個
split
和merge
API,其簽名與nnx.split
/nnx.merge
相同,但會執行一些簿記,以具有必要的資訊,以便根據轉換內所做的變更完美地更新輸入物件。UpdateContext 必須總共呼叫 split 和 merge 4 次,第一次和最後一次呼叫發生在轉換之外,第二次和第三次呼叫發生在轉換內部,如下圖所示idxmap (2) merge ─────────────────────────────► split (3) ▲ │ │ inside │ │. . . . . . . . . . . . . . . . . . │ index_mapping │ outside │ │ ▼ (1) split──────────────────────────────► merge (4) refmap
第一次呼叫 split
(1)
會建立一個refmap
,用來追蹤外部參照;而第一次呼叫 merge(2)
會建立一個idxmap
,用來追蹤內部參照。第二次呼叫 split(3)
會結合 refmap 和 idxmap 來產生index_mapping
,它會指出外部參照如何對應到內部參照。最後,最後一次呼叫 merge(4)
會使用 index_mapping 和 refmap 來重建轉換的輸出,同時重複使用/更新內部參照。為了避免記憶體洩漏,idxmap 會在(3)
之後清除,而 refmap 會在(4)
之後清除,兩者都會在上下文管理器結束後清除。以下是一個簡單的範例,展示如何使用
update_context
>>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> with nnx.update_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... @jax.jit ... def f(graphdef, state): ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 ... m2.ref = m2 # create a reference cycle ... return ctx.split(m2) ... graphdef_out, state_out = f(graphdef, state) ... m3 = ctx.merge(graphdef_out, state_out) ... >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1
請注意,
update_context
接受一個tag
引數,主要用作安全機制,以減少在使用current_update_context()
存取目前作用中的上下文時,意外使用錯誤 UpdateContext 的風險。current_update_context 可以用作存取目前作用中的上下文的方法,而無需將其作為捕獲傳遞。>>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> @jax.jit ... def f(graphdef, state): ... ctx = nnx.current_update_context('example') ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 # insert static attribute ... m2.ref = m2 # create a reference cycle ... return ctx.split(m2) ... >>> @nnx.update_context('example') ... def g(m1): ... ctx = nnx.current_update_context('example') ... graphdef, state = ctx.split(m1) ... graphdef_out, state_out = f(graphdef, state) ... return ctx.merge(graphdef_out, state_out) ... >>> m3 = g(m1) >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1
如以上程式碼所示,
update_context
也可用作裝飾器,在函式執行期間建立/啟用 UpdateContext 上下文。可以使用current_update_context()
存取上下文。- 參數
tag – 用於識別上下文的字串標籤。
- flax.nnx.current_update_context(tag)[原始碼]#
返回給定標籤的目前作用中的
UpdateContext
。