常見問題集 (FAQ)#
這是一系列對於常見問題 (FAQ) 的回答。你可以透過在 GitHub 討論區 開始一個新的主題讓 Flax FAQ 更充實。
如何取中間值的導數(使用 Module.perturb
)?#
若要取模型層中隱藏/中間活化對輸出的導數或梯度,你可以使用 flax.linen.Module.perturb()
。在正向傳遞中定義零值 flax.linen.Module
“擾動” 參數 — perturb(...)
— 具有與中間活化相同形狀,將損失函數定義為加上獨立論證 'perturbations'
,對擾動論證使用 jax.grad
執行 JAX 導數運算。
詳全文檔範例和文件,請前往
Flax Linen remat_scan()
是否與 scan(remat(...))
相同?#
Flax remat_scan()
(flax.linen.remat_scan()
) 和 scan(remat(...))
(flax.linen.scan()
在 flax.linen.remat()
) 不相同,而且 remat_scan()
僅支援特定情況。也就是說,remat_scan()
將輸入和輸出視為進位(在訓練迴圈中傳送的隱藏狀態)。建議使用 scan(remat(...))
,因為通常需要額外的參數,例如 in_axes
(輸入陣列軸)或 out_axes
(輸出陣列軸),而 flax.linen.remat_scan()
沒有公開這些參數。
建議的訓練迴圈程式庫有哪些?#
請考慮使用 CLU(常見迴圈工具) google/CommonLoopUtils。若要開始入門,請前往 CLU Synopsis Colab。您可以在 google/flax GitHub 討論 中找到常見的 CLU + Flax 相關問題解答。
查閱 Google 官方 範例,了解如何將訓練迴圈與 (CLU) 指標搭配使用。例如,這個範例是 Flax ImageNet 的 train.py。
對於電腦視覺研究,請考慮 google-research/scenic。Scenic 是一組輕量級共用程式庫,能夠解決訓練大型視覺模型時常見的任務(並提供許多專案範例)。Scenic 是使用 JAX 和 Flax 開發的。若要開始入門,請前往 GitHub 上的 README 頁面。