遷移學習#
此指南展示使用 Flax 執行遷移學習工作流程的各個部分。根據任務需求,可將預訓練模型直接當作特徵萃取器使用或作為較大型模型的一部分進行微調。
本指南說明如何
從 HuggingFace 變壓器 載入預訓練模型,然後從該預訓練模型萃取特定子模組。
建立分類器模型。
將預訓練參數轉移至新的模型結構。
建立最佳化器,使用 Optax 分別訓練模型的不同部分。
設定模型進行訓練。
效能注意事項
根據您的任務需求,本指南中部分內容可能無法達到最佳效果。例如,若只會在預訓練模型上訓練線性分類器,則最好僅萃取特徵嵌入一次,如此一來訓練速度將大幅提升,且可以使用專門用於線性回歸或邏輯分類的演算法。本指南說明如何使用所有模型參數執行遷移學習。
設定#
# Note that the Transformers library doesn't use the latest Flax version.
! pip install -q "transformers[flax]"
# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
# visit https://github.com/google/jax#installation.
! pip install -U -q flax jax jaxlib
建立用於載入模型的函數#
若要載入預訓練分類器,請先建立一個函數,回傳 Flax Module
及其預訓練變數,以利日後使用。
在下方的程式碼中,load_model
函數使用 HuggingFace 變壓器 函式庫中的 FlaxCLIPVisionModel
模型,並萃取 FlaxCLIPModule
模組。
%%capture
from IPython.display import clear_output
from transformers import FlaxCLIPModel
# Note: FlaxCLIPModel is not a Flax Module
def load_model():
clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')
clear_output(wait=False) # Clear the loading messages
module = clip.module # Extract the Flax Module
variables = {'params': clip.params} # Extract the parameters
return module, variables
請注意,FlaxCLIPVisionModel
本身並非 Flax Module
,這就是需要執行這項額外步驟的原因。
萃取子模組#
呼叫上方程式碼中的 load_model
會回傳 FlaxCLIPModule
,其由 text_model
和 vision_model
子模組組成。
提取 vision_model
子模塊(定義在 .setup()
中)及其變數的一種便捷方法,是在 clip
模塊上緊接著使用 flax.linen.Module.bind
,然後在 vision_model
子模塊上使用 flax.linen.Module.unbind
。
import flax.linen as nn
clip, clip_variables = load_model()
vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind()
建立分類器#
若要建立分類器,請定義一個新的 Flax Module
,它包含一個 backbone
(預先訓練的視覺模型)和一個 head
(分類器)子模塊。
from typing import Callable
import jax.numpy as jnp
import jax
class Classifier(nn.Module):
num_classes: int
backbone: nn.Module
@nn.compact
def __call__(self, x):
x = self.backbone(x).pooler_output
x = nn.Dense(
self.num_classes, name='head', kernel_init=nn.zeros)(x)
return x
若要建構分類器 model
,請將 vision_model
模塊傳遞給 Classifier
作為 backbone
。然後,可以傳遞用於推斷參數形狀的假資料,隨機初始化模型的 params
。
num_classes = 3
model = Classifier(num_classes=num_classes, backbone=vision_model)
x = jnp.empty((1, 224, 224, 3))
variables = model.init(jax.random.key(1), x)
params = variables['params']
轉移參數#
由於 params
目前是隨機的,因此必須將 vision_model_vars
中的預先訓練參數轉移到 params
結構中的適當位置(即 backbone
)
params['backbone'] = vision_model_vars['params']
注意:如果模型包含其他可變集合,例如 batch_stats
,也必須將它們轉移。
最佳化#
如果你需要分別訓練模型的不同部分,你有三個選項
使用
stop_gradient
。濾除
jax.grad
的參數。針對不同的參數使用多個最佳化器。
對於大多數情況,我們建議使用 Optax 的 multi_transform
,因為它既有效率,又可以透過許多不同的方法來擴充以實作許多微調策略。
optax.multi_transform#
要使用 optax.multi_transform
必須定義
參數分區。
區間與其優化器之間的對應。
一棵形狀與參數相同的 pytree,其節點包含對應的分區標籤。
要使用 optax.multi_transform
冷凍上述模型的層級,可以使用下列設定
定義
trainable
和frozen
參數分區。對於
trainable
參數,選擇 Adam (optax.adam
) 優化器。
對於
frozen
參數,選擇optax.set_to_zero
優化器。此虛擬優化器將梯度歸零,因此不會進行訓練。使用
flax.traverse_util.path_aware_map
將參數對應到分區,標註backbone
中的節點為frozen
,其餘部分為trainable
。
from flax import traverse_util
import optax
partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
param_partitions = traverse_util.path_aware_map(
lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)
tx = optax.multi_transform(partition_optimizers, param_partitions)
# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))
FrozenDict({
backbone: {
embeddings: {
class_embedding: 'frozen',
patch_embedding: {
kernel: 'frozen',
},
},
},
head: {
bias: 'trainable',
kernel: 'trainable',
},
})
要實作 差分學習率,optax.set_to_zero
可以替換為其他任何優化器,可以根據任務選擇不同的優化器和分區配置。有關進階優化器,請參閱 Optax 的 組合優化器 文件。
建立 TrainState
#
定義模組、參數和優化器後,即可照常建構 TrainState
from flax.training.train_state import TrainState
state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=tx)
由於優化器會處理冷凍或微調策略,train_step
不需要額外變更,訓練才能正常進行。