指標#
- class flax.nnx.metrics.Metric(*args, **kwargs)#
指標的基底類別。任何繼承
Metric
的類別都應實作compute
、reset
和update
方法。- __init__()#
- compute()#
計算並傳回
Metric
的值。
- reset()#
就地重設
Metric
。
- update(**kwargs)#
就地更新
Metric
。
- class flax.nnx.metrics.Average(*args, **kwargs)#
平均指標。
範例用法
>>> import jax.numpy as jnp >>> from flax import nnx >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics = nnx.metrics.Average() >>> metrics.compute() Array(nan, dtype=float32) >>> metrics.update(values=batch_loss) >>> metrics.compute() Array(2.5, dtype=float32) >>> metrics.update(values=batch_loss2) >>> metrics.compute() Array(2., dtype=float32) >>> metrics.reset() >>> metrics.compute() Array(nan, dtype=float32)
- __init__(argname='values')#
傳入一個字串,表示
update()
將用來取得新值的關鍵字引數。例如,將指標建構為avg = Average('test')
,您可以透過avg.update(test=new_value)
進行更新。- 參數
argname – 一個可選的字串,表示
update()
將用來取得新值的關鍵字引數。預設為'values'
。
- compute()#
計算並傳回平均值。
- reset()#
重設此
Metric
。
- update(**kwargs)#
就地更新此
Metric
。此方法將使用kwargs[self.argname]
中的值來更新指標,其中self.argname
是在建構時定義的。- 參數
**kwargs – 包含
self.argname
條目的關鍵字引數,該條目會對應到我們想要用來更新此指標的值。
- class flax.nnx.metrics.Accuracy(*args, **kwargs)#
準確度指標。此指標會繼承
Average
,因此它們會共用相同的reset
和compute
方法實作。與Average
不同,在建構期間,不需要將字串傳遞給Accuracy
。範例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([1, 1, 0, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) >>> metrics = nnx.metrics.Accuracy() >>> metrics.compute() Array(nan, dtype=float32) >>> metrics.update(logits=logits, labels=labels) >>> metrics.compute() Array(0.6, dtype=float32) >>> metrics.update(logits=logits2, labels=labels2) >>> metrics.compute() Array(0.7, dtype=float32) >>> metrics.reset() >>> metrics.compute() Array(nan, dtype=float32)
- update(*, logits, labels, **_)#
就地更新此
Metric
。- 參數
logits – 輸出的預測激活值。這些值會先進行 argmax 運算 (在尾隨維度上),然後再將它們與標籤進行比較。
labels – 真實的整數標籤。
- class flax.nnx.metrics.Welford(*args, **kwargs)#
使用 Welford 演算法來計算資料串流的平均值和變異數。
範例用法
>>> import jax.numpy as jnp >>> from flax import nnx >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics = nnx.metrics.Welford() >>> metrics.compute() Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32)) >>> metrics.update(values=batch_loss) >>> metrics.compute() Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32)) >>> metrics.update(values=batch_loss2) >>> metrics.compute() Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32)) >>> metrics.reset() >>> metrics.compute() Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
- __init__(argname='values')#
傳入一個字串,表示
update()
將用來取得新值的關鍵字引數。例如,將指標建構為wf = Welford('test')
,您可以透過wf.update(test=new_value)
進行更新。- 參數
argname – 一個可選的字串,表示
update()
將用來取得新值的關鍵字引數。預設為'values'
。
- compute()#
計算並以
Statistics
資料類別物件的形式傳回平均值和變異數統計資料。
- reset()#
重設此
Metric
。
- update(**kwargs)#
就地更新此
Metric
。此方法將使用kwargs[self.argname]
中的值來更新指標,其中self.argname
是在建構時定義的。- 參數
**kwargs – 包含
self.argname
條目的關鍵字引數,該條目會對應到我們想要用來更新此指標的值。
- class flax.nnx.metrics.MultiMetric(*args, **kwargs)#
MultiMetric 類別用於儲存多個指標,並在單次呼叫中更新它們。
範例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> metrics = nnx.MultiMetric( ... accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average() ... ) >>> metrics MultiMetric( accuracy=Accuracy( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ), loss=Average( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ) ) >>> metrics.accuracy Accuracy( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ) >>> metrics.loss Average( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ) >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([1, 1, 0, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} >>> metrics.update(logits=logits, labels=labels, values=batch_loss) >>> metrics.compute() {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)} >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2) >>> metrics.compute() {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)} >>> metrics.reset() >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
- __init__(**metrics)#
將關鍵字參數傳遞給建構子,例如
MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)
。- 參數
**metrics – 將用於存取相應
Metric
的關鍵字參數。
- compute()#
計算並返回所有底層
Metric
的值。此方法將返回一個字典,將字串(由傳遞給建構子的關鍵字參數**metrics
定義)映射到相應的指標值。
- reset()#
重置所有底層
Metric
。
- update(**updates)#
就地更新此
MultiMetric
中所有底層Metric
。所有**updates
將會被傳遞給所有底層Metric
的update
方法。- 參數
**updates – 將傳遞給底層
Metric
的update
方法的關鍵字參數。