Extracting intermediate values
This guide shows how to extract intermediate values from a module. Let's start with this simple CNN that uses nn.compact.
import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
Because this module uses nn.compact, we don't have direct access to intermediate values. There are a few ways to expose them:
Store intermediate values in a new variable collection
The CNN can be augmented with calls to sow to store intermediates, as follows:
class SowCNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    self.sow('intermediates', 'features', x)
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
sow acts as a no-op when the variable collection is not mutable, which makes it perfect for debugging and for optional tracking of intermediates. The 'intermediates' collection is also used by the capture_intermediates API (see the Use capture_intermediates section below).

Note that, by default, sow appends a value every time it is called. This is necessary because, once a module has been instantiated, it can be called multiple times by its parent module, and we want to capture all of the sown values. Therefore you should make sure you don't feed the intermediates back into variables; otherwise every call would grow the tuple and trigger a recompilation. To override the default append behavior, specify init_fn and reduce_fn - see Module.sow().
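For instance, here is a minimal sketch (the SumSow module name is made up for illustration) that overrides the append behavior to keep a running sum across calls instead of a growing tuple:

class SumSow(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.relu(nn.Dense(4)(x))
    # Accumulate a sum across calls instead of appending to a tuple, so
    # repeated calls never change the pytree structure.
    self.sow('intermediates', 'h', h,
             init_fn=lambda: 0.,
             reduce_fn=lambda acc, new: acc + new)
    return nn.Dense(2)(h)

Returning to the default append behavior, calling the same SowCNN instance twice yields a tuple of two sown values: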
class SowCNN2(nn.Module):
  @nn.compact
  def __call__(self, x):
    mod = SowCNN(name='SowCNN')
    return mod(x) + mod(x)  # Calling the same module instance twice.

@jax.jit
def init(key, x):
  variables = SowCNN2().init(key, x)
  # By default the 'intermediates' collection is not mutable during init,
  # so `variables` will only contain 'params' here.
  return variables

@jax.jit
def predict(variables, x):
  # If mutable='intermediates' is not specified, then .sow() acts as a no-op.
  output, mod_vars = SowCNN2().apply(variables, x, mutable='intermediates')
  features = mod_vars['intermediates']['SowCNN']['features']
  return output, features

batch = jnp.ones((1, 28, 28, 1))
variables = init(jax.random.key(0), batch)
preds, feats = predict(variables, batch)

assert len(feats) == 2  # Tuple with two values since the module was called twice.
Refactor the module into submodules
This is a useful pattern for cases where it's clear how you want to split up your submodules. Any submodule that you expose in setup can be used directly. In the limit, you can define all submodules in setup and avoid using nn.compact altogether.
class RefactoredCNN(nn.Module):
  def setup(self):
    self.features = Features()
    self.classifier = Classifier()

  def __call__(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x

class Features(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    return x

class Classifier(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

@jax.jit
def init(key, x):
  variables = RefactoredCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  return RefactoredCNN().apply({"params": params}, x,
                               method=lambda module, x: module.features(x))

params = init(jax.random.key(0), batch)
features(params, batch)
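In the same spirit, here is a sketch of a hypothetical classify helper (not part of the original example) that runs only the classifier head on precomputed features:

@jax.jit
def classify(params, feature_vec):
  # Run only the classifier submodule on already-extracted features.
  return RefactoredCNN().apply({"params": params}, feature_vec,
                               method=lambda module, x: module.classifier(x))

logits = classify(params, features(params, batch))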
Use capture_intermediates
Linen supports the automatic capturing of intermediate return values from submodules, without any code changes. This pattern should be considered the "sledgehammer" approach to capturing intermediates. As a debugging and inspection tool it's very useful, but the other patterns described in this guide give you more fine-grained control over which intermediates to extract.

In the following code example, we check whether any intermediate activations are non-finite (NaN or infinite):
@jax.jit
def init(key, x):
  variables = CNN().init(key, x)
  return variables

@jax.jit
def predict(variables, x):
  y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
  intermediates = state['intermediates']
  fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
  return y, fin

variables = init(jax.random.key(0), batch)
y, is_finite = predict(variables, batch)
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"
By default, only the intermediates of __call__ methods are collected. Alternatively, you can pass a custom filter function based on the Module instance and the method name.
filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
filter_encodings = lambda mdl, method_name: method_name == "encode"
y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']
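To see the method-name filter in action, here is a minimal sketch with a hypothetical AutoEncoder module (not part of the guide's CNN) that exposes a separate encode method:

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(4)
    self.decoder = nn.Dense(8)

  def encode(self, x):
    return self.encoder(x)

  def __call__(self, x):
    return self.decoder(self.encode(x))

ae = AutoEncoder()
ae_x = jnp.ones((1, 8))
ae_vars = ae.init(jax.random.key(0), ae_x)
_, ae_state = ae.apply(ae_vars, ae_x,
                       capture_intermediates=filter_encodings,
                       mutable=["intermediates"])
# Only the return value of the `encode` method is captured here.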
Note that capture_intermediates only applies to layers. You can use self.sow to store non-layer intermediates manually, but the filter function won't be applied to them.
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x)  # Dense_0
    b = nn.Dense(4)(x)  # Dense_1
    c = a + b           # not a Flax layer, so it won't be stored as an intermediate
    d = nn.Dense(4)(c)  # Dense_2
    return d

@jax.jit
def init(key, x):
  variables = Model().init(key, x)
  return variables['params']

@jax.jit
def predict(params, x):
  return Model().apply({"params": params}, x, capture_intermediates=True)

batch = jax.random.uniform(jax.random.key(1), (1, 3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)

feats  # intermediate c in Model was not stored because it's not a Flax layer
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x)  # Dense_0
    b = nn.Dense(4)(x)  # Dense_1
    c = a + b
    self.sow('intermediates', 'c', c)  # store intermediate c
    d = nn.Dense(4)(c)  # Dense_2
    return d

@jax.jit
def init(key, x):
  variables = Model().init(key, x)
  return variables['params']

@jax.jit
def predict(params, x):
  # Filter specifically for the Dense_0 and Dense_2 layers only.
  filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'})
  return Model().apply({"params": params}, x, capture_intermediates=filter_fn)

batch = jax.random.uniform(jax.random.key(1), (1, 3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)

feats  # intermediate c in Model is stored; the filter function is not applied to it
To isolate the intermediates extracted via self.sow from the intermediates captured by capture_intermediates, we can either define a separate collection, e.g. self.sow('sow_intermediates', 'c', c), or manually filter out the intermediates after calling .apply(). For example:
flattened_dict = flax.traverse_util.flatten_dict(feats['intermediates'], sep='/')
flattened_dict['c']
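Alternatively, here is a minimal sketch of the separate-collection approach, using a hypothetical SowModel (note that only 'params' is passed back into .apply(), in line with the earlier warning about feeding sown values back in):

class SowModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x)
    b = nn.Dense(4)(x)
    c = a + b
    # Store c in its own collection, so it never mixes with the
    # auto-captured 'intermediates'.
    self.sow('sow_intermediates', 'c', c)
    return nn.Dense(4)(c)

sow_vars = SowModel().init(jax.random.key(0), batch)
_, sow_state = SowModel().apply({'params': sow_vars['params']}, batch,
                                mutable=['sow_intermediates'])
c_values = sow_state['sow_intermediates']['c']  # tuple with a single entry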
On the subject of efficiency: as long as everything is JIT-compiled, any intermediates that you don't end up using should be optimized away by XLA.
Use Sequential
You could also define the CNN using a simple implementation of a Sequential combinator (this is quite common in more stateful approaches). This may be useful for very simple models and gives you arbitrary ways of slicing the model. But it can be very limiting: if you so much as want to add a single conditional, you're forced to refactor away from Sequential and structure your model more explicitly.
class Sequential(nn.Module):
  layers: Sequence[nn.Module]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

def SeqCNN():
  return Sequential([
    nn.Conv(features=32, kernel_size=(3, 3)),
    nn.relu,
    lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
    nn.Conv(features=64, kernel_size=(3, 3)),
    nn.relu,
    lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
    lambda x: x.reshape((x.shape[0], -1)),  # flatten
    nn.Dense(features=256),
    nn.relu,
    nn.Dense(features=10),
    nn.log_softmax,
  ])

@jax.jit
def init(key, x):
  variables = SeqCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)

batch = jnp.ones((1, 28, 28, 1))
params = init(jax.random.key(0), batch)
features(params, batch)
Extracting gradients of intermediate values
For debugging purposes, it can be useful to extract the gradients of intermediate values. This can be done by perturbing the desired variables with the Module.perturb() method:
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(8)(x))
    x = self.perturb('hidden', x)
    x = nn.Dense(2)(x)
    x = self.perturb('logits', x)
    return x
By default, perturb adds variables to the 'perturbations' collection; it behaves as an identity function, and the gradients with respect to the perturbations match the gradients of the intermediates. To get the perturbations, simply initialize the model:
x = jnp.empty((1, 4))  # dummy data
y = jnp.empty((1, 2))  # dummy data

model = Model()
variables = model.init(jax.random.key(1), x)
params, perturbations = variables['params'], variables['perturbations']
Finally, compute the gradients of the loss with respect to the perturbations; these will match the gradients of the intermediate values:
def loss_fn(params, perturbations, x, y):
  y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
  return jnp.mean((y_pred - y) ** 2)

intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)
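As a quick sanity check, the gradient pytree mirrors the perturbations (shapes follow from the Model defined above):

# One gradient entry per perturbation, shaped like the intermediate itself.
print(intermediate_grads['hidden'].shape)  # (1, 8)
print(intermediate_grads['logits'].shape)  # (1, 2)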