FP8使用者指南#
JAX 支援各種 FP8 格式,包括 E4M3 (jnp.float8_e4m3fn) 和 E5M2 (jnp.float8_e5m2)。由於 FP8 資料類型的範圍有限,因此必須將高精度的資料縮放以符合 FP8 可表示的範圍,這是一個稱為量化 (Q) 的過程。相反地,去量化 (DQ) 會將 FP8 資料擴充回其原始類型。
儘管 jnp.dot 支援 FP8 輸入,但是某些限制使其不適用於真實世界的應用。作為替代方案,我們的編譯器 XLA 可以識別類似
本教學課程將逐步指導您如何使用它的基礎知識。
設定我們的環境#
在此,我們提供設定筆記本環境所需的程式碼。此外,我們定義一個函數,用於檢查 XLA 最佳化的 HLO 是否確實會在後台呼叫 FP8 點運算。
注意:本教學課程依賴於 XLA-FP8 功能,而此功能僅在 NVIDIA Hopper GPU 或更新版本上受支援。
import flax
import jax
import re
import pprint
from jax import random
from jax import numpy as jnp
from jax._src import test_util as jtu
from flax import linen as nn
from flax.linen import fp8_ops
e4m3 = jnp.float8_e4m3fn
e5m2 = jnp.float8_e5m2
f32 = jnp.float32
E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)
assert jtu.is_cuda_compute_capability_at_least("9.0")
def check_fp8_call(lowered):
hlo = lowered.compile()
if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()):
print("Fp8 call detected!")
else:
print("No Fp8 call!")
FLAX 低階層 API#
JAX 點運算 (例如 jnp.dot
) 支援 FP8 資料類型輸入。因此,執行下列呼叫是合法的
key = random.key(0)
A = random.uniform(key, (16, 32))
B = random.uniform(key, (32, 64))
@jax.jit
def dot_fp8(A, B):
return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)
check_fp8_call(dot_fp8.lower(A, B))
但是,此方法有兩個主要問題。首先,jnp.dot
不接受運算元的縮放因子,其縮放因子預設為 1.0。其次,它不支援混合 FP8 資料類型的運算元。例如,當運算元為 E5M2 和 E4M3 時,會使用已提升的 FP16 資料類型執行點積。
在真實場景中,指定縮放因子(無論是訓練期間的校正或使用者的演算法)至關重要。此外,習慣上會使用 E5M2 進行梯度,而使用 E4M3 進行活化和內核。這些限制使得此方法對於實際應用而言不太實用。
為了克服這些限制並建立更通用 FP8 點積,我們建議利用 XLA-FP8。我們從一個簡單的縮放策略開始。
目前縮放#
縮放因子通常定義為 scale = amax(x) / MAX
,其中 amax
是用於找出張量絕對最大值的運算,而 MAX
是目標資料類型可表示範圍的最大值。此縮放方法使我們能夠直接從點積的目前運算元張量導出縮放因子。
@jax.jit
def dot_fp8(A, B):
A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX
B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX
A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)
B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)
C = jnp.dot(A, B)
return C
C = dot_fp8(A, B)
check_fp8_call(dot_fp8.lower(A, B))
如程式碼所示,我們對 dot 乘積的運算項執行假量化 (fp8_ops.quantize_dequantize
)。雖然 jnp.dot
仍處理較高精度的輸入,XLA 會偵測此模式並將 dot 運算重寫為 FP8 dot 呼叫(例如針對 GPU 的 cublasLt 呼叫)。此方法有效地模擬第一個範例,但提供更佳的彈性。我們可以控制輸入資料類型(兩者都在這裡設定為 E4M3,但我們可以使用混合 E4M3 和 E5M2)並定義 XLA 可以偵測並在 dot 後端使用的縮放因子。
目前縮放方法的一個主要問題是計算 A_scale
和 B_scale
產生的額外負擔,這需要額外載入運算項張量。為了克服這個問題,我們建議使用延遲縮放。
延遲縮放#
在延遲縮放中,我們使用與 amax 歷史記錄相關聯的縮放因子。縮放因子仍然是純量,但 amax 歷史記錄是一個儲存來自最近步驟(例如 1024 個步驟)的 amax 值的清單。兩個張量都是由前一個步驟計算的,並保存在模型參數中。
延遲縮放的假量化由 fp8_ops.in_qdq
(針對啟用及權重)和 fp8_ops.out_qdq
(針對梯度)提供。
a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
g_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
g_amax_hist = jnp.zeros((1024,))
@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,
g_scale, g_amax_hist):
a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)
b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)
c = jnp.dot(a, b)
c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)
return c
C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
g_scale, g_amax_hist)
check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
g_scale, g_amax_hist))
在此範例中,我們首先準備三組縮放因子和 amax 歷史記錄,將它們視為從前一個步驟計算的結果。然後,我們將 fp8_ops.in_qdq
套用於輸入 jnp.dot
的運算項,然後將 fp8_ops.out_qdq
套用於 jnp.dot
的輸出。請注意 fp8_ops.out_qdq
將透過 custom_vjp 函數將假量化套用於輸出的梯度。新的縮放因子和 amax 歷史記錄將透過其梯度傳回,這將在下一個部分中說明。
FLAX 高階 API#
透過 FLAX 函式庫,將 FP8 運算納入現有的 FLAX 層是一個順暢的流程。使用者不需要處理低階量化 API。他們可以用一種簡單的「程式碼注入」方法,將提供的自訂 FP8 dot (fp8_ops.Fp8DotGeneralOp
) 整合到 FLAX 層。這個自訂操作匯集了所有 FP8 相關任務,包括量化-去量化運算的配置、縮放因子更新演算法,以及用於正向和反向傳遞的 FP8 資料類型組合選取。
考量下列範例
model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
params = model.init(key, A)
@jax.jit
def train_step(var, a):
c = model.apply(var, a)
return jnp.sum(c)
check_fp8_call(train_step.lower(params, A))
在此範例中,我們只需設定 dot_general_cls=fp8_ops.Fp8DotGeneralOp
來啟用 Dense 層使用 FP8 點運算。模型使用方法幾乎與先前相同。主要差異在於新增一種類別的參數:縮放係數組和 amax 記錄。在下一節中,我們會探討如何更新這些參數。
處理 FP8 參數#
我們首先檢查 params
的資料結構。在以下程式碼中,我們將參數值塗黑,然後顯示 PyTree 結構。
params_structure = flax.core.unfreeze(params).copy()
params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')
for key, value in params_structure.items():
params_structure[key] = '*'
params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')
pprint.pprint(params_structure)
輸出如下
{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',
'input_scale': '*',
'kernel_amax_history': '*',
'kernel_scale': '*',
'output_grad_amax_history': '*',
'output_grad_scale': '*'}},
'params': {'bias': '*', 'kernel': '*'}}
除了預期的 params
之外,還有一個稱為 _overwrite_with_gradient
的額外類別。此類別包含三組分別針對啟用、核子和點梯度的 amax_history
和 scale
。
更新 FP8 參數的梯度#
現在,我們執行一個訓練步驟來取得梯度,並了解如何使用它們來更新參數。
step_fn = jax.jit(jax.grad(train_step, (0, 1)))
grads = step_fn(params, A)
params = flax.core.unfreeze(params)
params = flax.traverse_util.flatten_dict(params, sep='/')
grads = flax.traverse_util.flatten_dict(grads[0], sep='/')
for key, value in params.items():
if key.startswith('params'):
params[key] = value + 0.01 * grads[key]
if key.startswith('_overwrite_with_gradient'):
params[key] = grads[key]
params = flax.traverse_util.unflatten_dict(params, sep='/')
params = flax.core.freeze(params)
以上程式碼示範如何更新 params
和 _overwrite_with_gradient
。對於 params
,我們使用公式 new_param = old_param + 0.01 * grads
,其中 0.01
是學習率(或使用者可以使用 optax
中的任何最佳化器)。對於 _overwrite_with_gradient
,我們只需使用梯度來覆寫舊值。
請注意,flax.training.train_state.TrainState
方便地支援 _overwrite_with_gradient
的類別,所以如果使用者沒有使用自訂 TrainState
,則不需要變更他們的腳本。
累計 FP8 參數的梯度#
同一個參數使用在多分支時,autograd 機制會將這些分支的 gradient 加總。這在像 pipeline parallelism 等的情況很常見,每個微批次使用同一組參數做 minibatch。然而,對於 _overwrite_with_gradient
參數,這種透過加總的累積沒有意義。相反地,我們偏好透過取最大值來進行客製累積。
為了處理這一點,我們引入一個自定義數據型態 fp8_ops.fp32_max_grad
。以下展示基本的用法
fmax32 = fp8_ops.fp32_max_grad
def reuse_fp8_param(x, y, scale, amax_history):
scale = scale.astype(fmax32)
amax_history = amax_history.astype(fmax32)
x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)
y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)
return x + y
reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))
reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)
_, _, new_ah, new_sf = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)
print(new_ah, new_sf)
在此範例中,我們首先將 scale
和 amax_history
轉換成 fp8_ops.fp32_max_grad
,然後使用同一組 scale
和 amax_history
呼叫 fp8_ops.in_qdq
兩次。在 autograd 過程中,我們會取各個分支的 gradient 的最大值,得到以下正確結果
1.0 [3. 0. 0. ... 0. 0. 0.]
如果我們不執行類型轉換,則會得到以下結果,代表兩個分支的 gradient 加總了
2.0 [5. 0. 0. ... 0. 0. 0.]
如果使用者選擇使用高級 API,此轉換已包含在其中。