FP8使用者指南#

JAX 支援各種 FP8 格式,包括 E4M3 (jnp.float8_e4m3fn) 和 E5M2 (jnp.float8_e5m2)。由於 FP8 資料類型的範圍有限,因此必須將高精度的資料縮放以符合 FP8 可表示的範圍,這是一個稱為量化 (Q) 的過程。相反地,去量化 (DQ) 會將 FP8 資料擴充回其原始類型。

儘管 jnp.dot 支援 FP8 輸入,但是某些限制使其不適用於真實世界的應用。作為替代方案,我們的編譯器 XLA 可以識別類似->DQ->Dot 的模式,並隨後調用 FP8 後端 (例如,GPU 的 cublasLt)。FLAX 將這些模式封裝到 nn.fp8_ops.Fp8DotGeneralOp 模組中,使用戶可以輕鬆地將其配置到現有的層 (例如 nn.Dense)。

本教學課程將逐步指導您如何使用它的基礎知識。

設定我們的環境#

在此,我們提供設定筆記本環境所需的程式碼。此外,我們定義一個函數,用於檢查 XLA 最佳化的 HLO 是否確實會在後台呼叫 FP8 點運算。

注意:本教學課程依賴於 XLA-FP8 功能,而此功能僅在 NVIDIA Hopper GPU 或更新版本上受支援。

import flax
import jax
import re
import pprint
from jax import random
from jax import numpy as jnp
from jax._src import test_util as jtu
from flax import linen as nn
from flax.linen import fp8_ops

e4m3 = jnp.float8_e4m3fn
e5m2 = jnp.float8_e5m2
f32 = jnp.float32
E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)

assert jtu.is_cuda_compute_capability_at_least("9.0")

def check_fp8_call(lowered):
  hlo = lowered.compile()
  if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()):
    print("Fp8 call detected!")
  else:
    print("No Fp8 call!")

FLAX 低階層 API#

JAX 點運算 (例如 jnp.dot) 支援 FP8 資料類型輸入。因此,執行下列呼叫是合法的

key = random.key(0)
A = random.uniform(key, (16, 32))
B = random.uniform(key, (32, 64))
@jax.jit
def dot_fp8(A, B):
  return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)
check_fp8_call(dot_fp8.lower(A, B))

但是,此方法有兩個主要問題。首先,jnp.dot 不接受運算元的縮放因子,其縮放因子預設為 1.0。其次,它不支援混合 FP8 資料類型的運算元。例如,當運算元為 E5M2 和 E4M3 時,會使用已提升的 FP16 資料類型執行點積。

在真實場景中,指定縮放因子(無論是訓練期間的校正或使用者的演算法)至關重要。此外,習慣上會使用 E5M2 進行梯度,而使用 E4M3 進行活化和內核。這些限制使得此方法對於實際應用而言不太實用。

為了克服這些限制並建立更通用 FP8 點積,我們建議利用 XLA-FP8。我們從一個簡單的縮放策略開始。

目前縮放#

縮放因子通常定義為 scale = amax(x) / MAX,其中 amax 是用於找出張量絕對最大值的運算,而 MAX 是目標資料類型可表示範圍的最大值。此縮放方法使我們能夠直接從點積的目前運算元張量導出縮放因子。

@jax.jit
def dot_fp8(A, B):
  A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX
  B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX
  A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)
  B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)

  C = jnp.dot(A, B)
  return C

C = dot_fp8(A, B)
check_fp8_call(dot_fp8.lower(A, B))

如程式碼所示,我們對 dot 乘積的運算項執行假量化 (fp8_ops.quantize_dequantize)。雖然 jnp.dot 仍處理較高精度的輸入,XLA 會偵測此模式並將 dot 運算重寫為 FP8 dot 呼叫(例如針對 GPU 的 cublasLt 呼叫)。此方法有效地模擬第一個範例,但提供更佳的彈性。我們可以控制輸入資料類型(兩者都在這裡設定為 E4M3,但我們可以使用混合 E4M3 和 E5M2)並定義 XLA 可以偵測並在 dot 後端使用的縮放因子。

目前縮放方法的一個主要問題是計算 A_scaleB_scale 產生的額外負擔,這需要額外載入運算項張量。為了克服這個問題,我們建議使用延遲縮放。

延遲縮放#

在延遲縮放中,我們使用與 amax 歷史記錄相關聯的縮放因子。縮放因子仍然是純量,但 amax 歷史記錄是一個儲存來自最近步驟(例如 1024 個步驟)的 amax 值的清單。兩個張量都是由前一個步驟計算的,並保存在模型參數中。

延遲縮放的假量化由 fp8_ops.in_qdq(針對啟用及權重)和 fp8_ops.out_qdq(針對梯度)提供。

a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
g_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
g_amax_hist = jnp.zeros((1024,))

@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,
            g_scale, g_amax_hist):
  a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)
  b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)
  
  c = jnp.dot(a, b)
  c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)
  return c

C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
            g_scale, g_amax_hist)
check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
                             g_scale, g_amax_hist))

在此範例中,我們首先準備三組縮放因子和 amax 歷史記錄,將它們視為從前一個步驟計算的結果。然後,我們將 fp8_ops.in_qdq 套用於輸入 jnp.dot 的運算項,然後將 fp8_ops.out_qdq 套用於 jnp.dot 的輸出。請注意 fp8_ops.out_qdq 將透過 custom_vjp 函數將假量化套用於輸出的梯度。新的縮放因子和 amax 歷史記錄將透過其梯度傳回,這將在下一個部分中說明。

FLAX 高階 API#

透過 FLAX 函式庫,將 FP8 運算納入現有的 FLAX 層是一個順暢的流程。使用者不需要處理低階量化 API。他們可以用一種簡單的「程式碼注入」方法,將提供的自訂 FP8 dot (fp8_ops.Fp8DotGeneralOp) 整合到 FLAX 層。這個自訂操作匯集了所有 FP8 相關任務,包括量化-去量化運算的配置、縮放因子更新演算法,以及用於正向和反向傳遞的 FP8 資料類型組合選取。

考量下列範例

model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
params = model.init(key, A)

@jax.jit
def train_step(var, a): 
  c = model.apply(var, a)
  return jnp.sum(c)

check_fp8_call(train_step.lower(params, A))

在此範例中,我們只需設定 dot_general_cls=fp8_ops.Fp8DotGeneralOp 來啟用 Dense 層使用 FP8 點運算。模型使用方法幾乎與先前相同。主要差異在於新增一種類別的參數:縮放係數組和 amax 記錄。在下一節中,我們會探討如何更新這些參數。

處理 FP8 參數#

我們首先檢查 params 的資料結構。在以下程式碼中,我們將參數值塗黑,然後顯示 PyTree 結構。

params_structure = flax.core.unfreeze(params).copy()
params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')
for key, value in params_structure.items():
    params_structure[key] = '*'
params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')
pprint.pprint(params_structure)

輸出如下

{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',
                                                    'input_scale': '*',
                                                    'kernel_amax_history': '*',
                                                    'kernel_scale': '*',
                                                    'output_grad_amax_history': '*',
                                                    'output_grad_scale': '*'}},
 'params': {'bias': '*', 'kernel': '*'}}

除了預期的 params 之外,還有一個稱為 _overwrite_with_gradient 的額外類別。此類別包含三組分別針對啟用、核子和點梯度的 amax_historyscale

更新 FP8 參數的梯度#

現在,我們執行一個訓練步驟來取得梯度,並了解如何使用它們來更新參數。

step_fn = jax.jit(jax.grad(train_step, (0, 1)))

grads = step_fn(params, A)

params = flax.core.unfreeze(params)
params = flax.traverse_util.flatten_dict(params, sep='/')
grads = flax.traverse_util.flatten_dict(grads[0], sep='/')

for key, value in params.items():
  if key.startswith('params'):
    params[key] = value + 0.01 * grads[key]
  if key.startswith('_overwrite_with_gradient'):
    params[key] = grads[key]

params = flax.traverse_util.unflatten_dict(params, sep='/')
params = flax.core.freeze(params)

以上程式碼示範如何更新 params_overwrite_with_gradient。對於 params,我們使用公式 new_param = old_param + 0.01 * grads,其中 0.01 是學習率(或使用者可以使用 optax 中的任何最佳化器)。對於 _overwrite_with_gradient,我們只需使用梯度來覆寫舊值。

請注意,flax.training.train_state.TrainState 方便地支援 _overwrite_with_gradient 的類別,所以如果使用者沒有使用自訂 TrainState,則不需要變更他們的腳本。

累計 FP8 參數的梯度#

同一個參數使用在多分支時,autograd 機制會將這些分支的 gradient 加總。這在像 pipeline parallelism 等的情況很常見,每個微批次使用同一組參數做 minibatch。然而,對於 _overwrite_with_gradient 參數,這種透過加總的累積沒有意義。相反地,我們偏好透過取最大值來進行客製累積。

為了處理這一點,我們引入一個自定義數據型態 fp8_ops.fp32_max_grad。以下展示基本的用法

fmax32 = fp8_ops.fp32_max_grad

def reuse_fp8_param(x, y, scale, amax_history):
  scale = scale.astype(fmax32)
  amax_history = amax_history.astype(fmax32)

  x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)
  y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)
  return x + y

reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))
reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)

_, _, new_ah, new_sf = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)
print(new_ah, new_sf)

在此範例中,我們首先將 scaleamax_history 轉換成 fp8_ops.fp32_max_grad,然後使用同一組 scaleamax_history 呼叫 fp8_ops.in_qdq 兩次。在 autograd 過程中,我們會取各個分支的 gradient 的最大值,得到以下正確結果

1.0 [3. 0. 0. ... 0. 0. 0.]

如果我們不執行類型轉換,則會得到以下結果,代表兩個分支的 gradient 加總了

2.0 [5. 0. 0. ... 0. 0. 0.]

如果使用者選擇使用高級 API,此轉換已包含在其中。