遷移學習#

此指南展示使用 Flax 執行遷移學習工作流程的各個部分。根據任務需求,可將預訓練模型直接當作特徵萃取器使用或作為較大型模型的一部分進行微調。

本指南說明如何

  • 從 HuggingFace 變壓器 載入預訓練模型,然後從該預訓練模型萃取特定子模組。

  • 建立分類器模型。

  • 將預訓練參數轉移至新的模型結構。

  • 建立最佳化器,使用 Optax 分別訓練模型的不同部分。

  • 設定模型進行訓練。

效能注意事項

根據您的任務需求,本指南中部分內容可能無法達到最佳效果。例如,若只會在預訓練模型上訓練線性分類器,則最好僅萃取特徵嵌入一次,如此一來訓練速度將大幅提升,且可以使用專門用於線性回歸或邏輯分類的演算法。本指南說明如何使用所有模型參數執行遷移學習。


設定#

# Note that the Transformers library doesn't use the latest Flax version.
! pip install -q "transformers[flax]"
# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
# visit https://github.com/google/jax#installation.
! pip install -U -q flax jax jaxlib

建立用於載入模型的函數#

若要載入預訓練分類器,請先建立一個函數,回傳 Flax Module 及其預訓練變數,以利日後使用。

在下方的程式碼中,load_model 函數使用 HuggingFace 變壓器 函式庫中的 FlaxCLIPVisionModel 模型,並萃取 FlaxCLIPModule 模組。

%%capture
from IPython.display import clear_output
from transformers import FlaxCLIPModel

# Note: FlaxCLIPModel is not a Flax Module
def load_model():
  clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')
  clear_output(wait=False) # Clear the loading messages
  module = clip.module # Extract the Flax Module
  variables = {'params': clip.params} # Extract the parameters
  return module, variables

請注意,FlaxCLIPVisionModel 本身並非 Flax Module,這就是需要執行這項額外步驟的原因。

萃取子模組#

呼叫上方程式碼中的 load_model 會回傳 FlaxCLIPModule,其由 text_modelvision_model 子模組組成。

提取 vision_model 子模塊(定義在 .setup() 中)及其變數的一種便捷方法,是在 clip 模塊上緊接著使用 flax.linen.Module.bind ,然後在 vision_model 子模塊上使用 flax.linen.Module.unbind

import flax.linen as nn

clip, clip_variables = load_model()
vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind()

建立分類器#

若要建立分類器,請定義一個新的 Flax Module,它包含一個 backbone (預先訓練的視覺模型)和一個 head (分類器)子模塊。

from typing import Callable
import jax.numpy as jnp
import jax

class Classifier(nn.Module):
  num_classes: int
  backbone: nn.Module
  

  @nn.compact
  def __call__(self, x):
    x = self.backbone(x).pooler_output
    x = nn.Dense(
      self.num_classes, name='head', kernel_init=nn.zeros)(x)
    return x

若要建構分類器 model,請將 vision_model 模塊傳遞給 Classifier 作為 backbone。然後,可以傳遞用於推斷參數形狀的假資料,隨機初始化模型的 params

num_classes = 3
model = Classifier(num_classes=num_classes, backbone=vision_model)

x = jnp.empty((1, 224, 224, 3))
variables = model.init(jax.random.key(1), x)
params = variables['params']

轉移參數#

由於 params 目前是隨機的,因此必須將 vision_model_vars 中的預先訓練參數轉移到 params 結構中的適當位置(即 backbone

params['backbone'] = vision_model_vars['params']

注意:如果模型包含其他可變集合,例如 batch_stats,也必須將它們轉移。

最佳化#

如果你需要分別訓練模型的不同部分,你有三個選項

  1. 使用 stop_gradient

  2. 濾除 jax.grad 的參數。

  3. 針對不同的參數使用多個最佳化器。

對於大多數情況,我們建議使用 Optaxmulti_transform,因為它既有效率,又可以透過許多不同的方法來擴充以實作許多微調策略。

optax.multi_transform#

要使用 optax.multi_transform 必須定義

  1. 參數分區。

  2. 區間與其優化器之間的對應。

  3. 一棵形狀與參數相同的 pytree,其節點包含對應的分區標籤。

要使用 optax.multi_transform 冷凍上述模型的層級,可以使用下列設定

  • 定義 trainablefrozen 參數分區。

  • 對於 trainable 參數,選擇 Adam (optax.adam) 優化器。

  • 對於 frozen 參數,選擇 optax.set_to_zero 優化器。此虛擬優化器將梯度歸零,因此不會進行訓練。

  • 使用 flax.traverse_util.path_aware_map 將參數對應到分區,標註 backbone 中的節點為 frozen,其餘部分為 trainable

from flax import traverse_util
import optax

partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
param_partitions = traverse_util.path_aware_map(
  lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)
tx = optax.multi_transform(partition_optimizers, param_partitions)

# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))
FrozenDict({
    backbone: {
        embeddings: {
            class_embedding: 'frozen',
            patch_embedding: {
                kernel: 'frozen',
            },
        },
    },
    head: {
        bias: 'trainable',
        kernel: 'trainable',
    },
})

要實作 差分學習率optax.set_to_zero 可以替換為其他任何優化器,可以根據任務選擇不同的優化器和分區配置。有關進階優化器,請參閱 Optax 的 組合優化器 文件。

建立 TrainState#

定義模組、參數和優化器後,即可照常建構 TrainState

from flax.training.train_state import TrainState

state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=tx)

由於優化器會處理冷凍或微調策略,train_step 不需要額外變更,訓練才能正常進行。