Extracting intermediate values

This guide shows how to extract intermediate values from a module. Let's start with this simple CNN that uses nn.compact.

import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

Because this module uses nn.compact, we don't have direct access to the intermediate values. There are a few ways to expose them:

Store intermediates in a new variable collection

The CNN can be augmented with calls to sow to store intermediates, as follows:

class SowCNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    self.sow('intermediates', 'features', x)
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

sow acts as a no-op when the variable collection is not mutable. This makes it a great fit for debugging and for optional tracking of intermediates. The 'intermediates' collection is also used by the capture_intermediates API (see the Use capture_intermediates section).
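As a quick illustration of the no-op behavior, here is a small sketch (assuming the SowCNN defined above and a dummy MNIST-shaped batch): since 'intermediates' is not marked as mutable, sow stores nothing and apply returns only the output.

sow_variables = SowCNN().init(jax.random.key(0), jnp.ones((1, 28, 28, 1)))
# 'intermediates' is not mutable here, so self.sow(...) is a no-op and
# apply returns just the model output.
output = SowCNN().apply(sow_variables, jnp.ones((1, 28, 28, 1)))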

Note that, by default, sow appends a value every time it is called:

  • This is necessary because, once a module has been instantiated, it can be called multiple times inside its parent module, and we want to capture all of the sown values.

  • Because of this, make sure you don't feed the intermediates back into variables. Otherwise, each call would increase the length of that tuple and trigger a recompilation.

  • To override the default append behavior, specify init_fn and reduce_fn - see Module.sow() and the sketch below.
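For example, a minimal sketch (the module name and layer sizes are illustrative, not from this guide) that overrides reduce_fn so only the most recent value is kept instead of appending:

class LastValueModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(4)(x)
    # reduce_fn receives the previously stored value and the new one;
    # returning the new one overwrites rather than appending to a tuple.
    self.sow('intermediates', 'hidden', x,
             reduce_fn=lambda old, new: new)
    return nn.Dense(2)(x)

Returning to the default append behavior, the example below calls the same SowCNN instance twice, so two sown values are collected: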

class SowCNN2(nn.Module):
  @nn.compact
  def __call__(self, x):
    mod = SowCNN(name='SowCNN')
    return mod(x) + mod(x)  # Calling same module instance twice.

@jax.jit
def init(key, x):
  variables = SowCNN2().init(key, x)
  # By default the 'intermediates' collection is not mutable during init.
  # So variables will only contain 'params' here.
  return variables

@jax.jit
def predict(variables, x):
  # If mutable='intermediates' is not specified, then .sow() acts as a noop.
  output, mod_vars = SowCNN2().apply(variables, x, mutable='intermediates')
  features = mod_vars['intermediates']['SowCNN']['features']
  return output, features

batch = jnp.ones((1,28,28,1))
variables = init(jax.random.key(0), batch)
preds, feats = predict(variables, batch)

assert len(feats) == 2  # Tuple with two values since module was called twice.

Refactor module into submodules

This is a useful pattern for cases where it's clear how you want to split up your submodules. Any submodule you expose in setup can be used directly. In the limit, you can define all submodules in setup and avoid using nn.compact altogether.

class RefactoredCNN(nn.Module):
  def setup(self):
    self.features = Features()
    self.classifier = Classifier()

  def __call__(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x

class Features(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    return x

class Classifier(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

@jax.jit
def init(key, x):
  variables = RefactoredCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  return RefactoredCNN().apply({"params": params}, x,
    method=lambda module, x: module.features(x))

params = init(jax.random.key(0), batch)

features(params, batch)
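In the same spirit, here is a sketch (reusing params from the init above; the classify name is illustrative) that runs only the classifier submodule on precomputed features, again via the method argument:

@jax.jit
def classify(params, feats):
  # Apply only the `classifier` submodule; the unused `features` params are simply ignored.
  return RefactoredCNN().apply({"params": params}, feats,
    method=lambda module, feats: module.classifier(feats))

logits = classify(params, features(params, batch))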

Use capture_intermediates

Linen supports the automatic capture of intermediate return values from submodules, without any code changes. This pattern should be considered the "sledgehammer" approach to capturing intermediates. It's very useful as a debugging and inspection tool, but the other patterns described in this guide give you more fine-grained control over which intermediates to extract.

In the following code example we check whether any intermediate activations are non-finite (NaN or infinite):

@jax.jit
def init(key, x):
  variables = CNN().init(key, x)
  return variables

@jax.jit
def predict(variables, x):
  y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
  intermediates = state['intermediates']
  fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
  return y, fin

variables = init(jax.random.key(0), batch)
y, is_finite = predict(variables, batch)
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"

By default, only the intermediates of __call__ methods are collected. Alternatively, you can pass a custom filter function based on the Module instance and the method name.

filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
filter_encodings = lambda mdl, method_name: method_name == "encode"

y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']
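The filter_encodings filter above matches methods named encode. As a sketch (the AutoEncoder module below is hypothetical and not part of this guide), it can be used to capture the return value of an encode method rather than __call__:

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(4)
    self.decoder = nn.Dense(8)

  def encode(self, x):
    return self.encoder(x)

  def __call__(self, x):
    return self.decoder(self.encode(x))

ae_batch = jnp.ones((1, 8))
ae_variables = AutoEncoder().init(jax.random.key(0), ae_batch)
_, ae_state = AutoEncoder().apply(
    ae_variables, ae_batch,
    capture_intermediates=filter_encodings, mutable=['intermediates'])
ae_state['intermediates']['encode']  # return value of the encode method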

Note that capture_intermediates only applies to layers. You can use self.sow to manually store non-layer intermediates, but the filter function will not be applied to them.

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x) # Dense_0
    b = nn.Dense(4)(x) # Dense_1
    c = a + b # not a Flax layer, so won't be stored as an intermediate
    d = nn.Dense(4)(c) # Dense_2
    return d

@jax.jit
def init(key, x):
  variables = Model().init(key, x)
  return variables['params']

@jax.jit
def predict(params, x):
  return Model().apply({"params": params}, x, capture_intermediates=True)

batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats  # intermediate c in Model was not stored because it's not a Flax layer

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x) # Dense_0
    b = nn.Dense(4)(x) # Dense_1
    c = a + b
    self.sow('intermediates', 'c', c) # store intermediate c
    d = nn.Dense(4)(c) # Dense_2
    return d

@jax.jit
def init(key, x):
  variables = Model().init(key, x)
  return variables['params']

@jax.jit
def predict(params, x):
  # filter specifically for only the Dense_0 and Dense_2 layer
  filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'})
  return Model().apply({"params": params}, x, capture_intermediates=filter_fn)

batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model is stored and isn't filtered out by the filter function

To separate the intermediates captured by capture_intermediates from those captured with self.sow, we can either define a separate collection, e.g. self.sow('sow_intermediates', 'c', c), or manually filter the intermediates after calling .apply(). For example:

flattened_dict = flax.traverse_util.flatten_dict(feats['intermediates'], sep='/')
flattened_dict['c']
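Alternatively, a sketch of the separate-collection approach (the SowModel name and layer sizes are illustrative): sow into a 'sow_intermediates' collection and mark both collections as mutable, so the manually sown value ends up in its own collection.

class SowModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x)
    b = nn.Dense(4)(x)
    c = a + b
    self.sow('sow_intermediates', 'c', c)  # stored in its own collection
    return nn.Dense(4)(c)

sow_params = SowModel().init(jax.random.key(0), batch)['params']
_, state = SowModel().apply(
    {'params': sow_params}, batch, capture_intermediates=True,
    mutable=['intermediates', 'sow_intermediates'])
state['sow_intermediates']['c']  # only the manually sown value
state['intermediates']           # only the automatically captured layer outputs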

In terms of efficiency: as long as everything is jitted, any intermediates that you end up not using should be optimized away by XLA.

Use Sequential

You could also define the CNN using a simple implementation of a Sequential combinator (this is quite common in more stateful approaches). This can be useful for very simple models and lets you slice the model arbitrarily. But it is very limited - as soon as you want to add even a single conditional, you're forced to refactor away from Sequential and structure your model more explicitly.

class Sequential(nn.Module):
  layers: Sequence[nn.Module]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

def SeqCNN():
  return Sequential([
    nn.Conv(features=32, kernel_size=(3, 3)),
    nn.relu,
    lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
    nn.Conv(features=64, kernel_size=(3, 3)),
    nn.relu,
    lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
    lambda x: x.reshape((x.shape[0], -1)),  # flatten
    nn.Dense(features=256),
    nn.relu,
    nn.Dense(features=10),
    nn.log_softmax,
  ])

@jax.jit
def init(key, x):
  variables = SeqCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)

batch = jnp.ones((1,28,28,1))
params = init(jax.random.key(0), batch)
features(params, batch)

Extract gradients of intermediate values

For debugging purposes it can be useful to extract the gradients of intermediate values. This can be done by adding perturbations to the desired values with the Module.perturb() method:

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(8)(x))
    x = self.perturb('hidden', x)
    x = nn.Dense(2)(x)
    x = self.perturb('logits', x)
    return x

By default, perturb adds the variables to the 'perturbations' collection. It behaves like an identity function, and the gradients of the perturbations match the gradients of the inputs they wrap. To get the perturbations, simply initialize the model:

x = jnp.empty((1, 4))  # dummy input data
y = jnp.empty((1, 2))  # dummy target data

model = Model()
variables = model.init(jax.random.key(1), x)
params, perturbations = variables['params'], variables['perturbations']

Finally, compute the gradients of the loss with respect to the perturbations; these will match the gradients of the intermediate values:

def loss_fn(params, perturbations, x, y):
  y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
  return jnp.mean((y_pred - y) ** 2)

intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)
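As a quick sanity check (shapes follow from the model above), each gradient entry has the same shape as the value it perturbs:

intermediate_grads['hidden'].shape  # (1, 8) - gradient w.r.t. the hidden activations
intermediate_grads['logits'].shape  # (1, 2) - gradient w.r.t. the logits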