常見問題集 (FAQ)#

這是一系列對於常見問題 (FAQ) 的回答。你可以透過在 GitHub 討論區 開始一個新的主題讓 Flax FAQ 更充實。

如何取中間值的導數(使用 Module.perturb)?#

若要取模型層中隱藏/中間活化對輸出的導數或梯度,你可以使用 flax.linen.Module.perturb()。在正向傳遞中定義零值 flax.linen.Module “擾動” 參數 — perturb(...) — 具有與中間活化相同形狀,將損失函數定義為加上獨立論證 'perturbations',對擾動論證使用 jax.grad 執行 JAX 導數運算。

詳全文檔範例和文件,請前往

Flax Linen remat_scan() 是否與 scan(remat(...)) 相同?#

Flax remat_scan() (flax.linen.remat_scan()) 和 scan(remat(...)) (flax.linen.scan()flax.linen.remat()) 不相同,而且 remat_scan() 僅支援特定情況。也就是說,remat_scan() 將輸入和輸出視為進位(在訓練迴圈中傳送的隱藏狀態)。建議使用 scan(remat(...)),因為通常需要額外的參數,例如 in_axes(輸入陣列軸)或 out_axes(輸出陣列軸),而 flax.linen.remat_scan() 沒有公開這些參數。