指標#

class flax.nnx.metrics.Metric(*args, **kwargs)#

指標的基底類別。任何繼承 Metric 的類別都應實作 computeresetupdate 方法。

__init__()#
compute()#

計算並傳回 Metric 的值。

reset()#

就地重設 Metric

update(**kwargs)#

就地更新 Metric

class flax.nnx.metrics.Average(*args, **kwargs)#

平均指標。

範例用法

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Average()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
__init__(argname='values')#

傳入一個字串,表示 update() 將用來取得新值的關鍵字引數。例如,將指標建構為 avg = Average('test'),您可以透過 avg.update(test=new_value) 進行更新。

參數

argname – 一個可選的字串,表示 update() 將用來取得新值的關鍵字引數。預設為 'values'

compute()#

計算並傳回平均值。

reset()#

重設此 Metric

update(**kwargs)#

就地更新此 Metric。此方法將使用 kwargs[self.argname] 中的值來更新指標,其中 self.argname 是在建構時定義的。

參數

**kwargs – 包含 self.argname 條目的關鍵字引數,該條目會對應到我們想要用來更新此指標的值。

class flax.nnx.metrics.Accuracy(*args, **kwargs)#

準確度指標。此指標會繼承 Average,因此它們會共用相同的 resetcompute 方法實作。與 Average 不同,在建構期間,不需要將字串傳遞給 Accuracy

範例用法

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> metrics = nnx.metrics.Accuracy()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.7, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
update(*, logits, labels, **_)#

就地更新此 Metric

參數
  • logits – 輸出的預測激活值。這些值會先進行 argmax 運算 (在尾隨維度上),然後再將它們與標籤進行比較。

  • labels – 真實的整數標籤。

class flax.nnx.metrics.Welford(*args, **kwargs)#

使用 Welford 演算法來計算資料串流的平均值和變異數。

範例用法

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Welford()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
__init__(argname='values')#

傳入一個字串,表示 update() 將用來取得新值的關鍵字引數。例如,將指標建構為 wf = Welford('test'),您可以透過 wf.update(test=new_value) 進行更新。

參數

argname – 一個可選的字串,表示 update() 將用來取得新值的關鍵字引數。預設為 'values'

compute()#

計算並以 Statistics 資料類別物件的形式傳回平均值和變異數統計資料。

reset()#

重設此 Metric

update(**kwargs)#

就地更新此 Metric。此方法將使用 kwargs[self.argname] 中的值來更新指標,其中 self.argname 是在建構時定義的。

參數

**kwargs – 包含 self.argname 條目的關鍵字引數,該條目會對應到我們想要用來更新此指標的值。

class flax.nnx.metrics.MultiMetric(*args, **kwargs)#

MultiMetric 類別用於儲存多個指標,並在單次呼叫中更新它們。

範例用法

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> metrics = nnx.MultiMetric(
...   accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )

>>> metrics
MultiMetric(
  accuracy=Accuracy(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  ),
  loss=Average(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  )
)

>>> metrics.accuracy
Accuracy(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)

>>> metrics.loss
Average(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
>>> metrics.compute()
{'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
__init__(**metrics)#

將關鍵字參數傳遞給建構子,例如 MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)

參數

**metrics – 將用於存取相應 Metric 的關鍵字參數。

compute()#

計算並返回所有底層 Metric 的值。此方法將返回一個字典,將字串(由傳遞給建構子的關鍵字參數 **metrics 定義)映射到相應的指標值。

reset()#

重置所有底層 Metric

update(**updates)#

就地更新此 MultiMetric 中所有底層 Metric。所有 **updates 將會被傳遞給所有底層 Metricupdate 方法。

參數

**updates – 將傳遞給底層 Metricupdate 方法的關鍵字參數。