範例:使用預訓練的 Gemma 模型搭配 Flax NNX 進行推論#
本範例示範如何使用 Flax NNX 載入 Gemma 開放模型檔案,並使用它們執行取樣/推論以產生文字。您將使用以 Flax 和 JAX 編寫的 Flax NNX gemma
模組進行模型參數配置和推論。
建議您使用可存取 A100 GPU 加速的 Google Colab 來執行程式碼。
安裝#
安裝必要的依賴項,包括 kagglehub
。
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
下載模型#
若要使用 Gemma 模型,您需要一個 Kaggle 帳戶和 API 金鑰
若要建立帳戶,請造訪 Kaggle 並點擊「註冊」。
如果/一旦您擁有帳戶,您需要登入,前往您的 「設定」,並在「API」下點擊「建立新權杖」以產生並下載您的 Kaggle API 金鑰。
在 Google Colab 中,在「機密」下新增您的 Kaggle 使用者名稱和 API 金鑰,將使用者名稱儲存為
KAGGLE_USERNAME
,金鑰儲存為KAGGLE_KEY
。如果您正在使用 Kaggle Notebook 進行免費 TPU 或其他硬體加速,它在「附加元件」>「機密」下有一個金鑰儲存功能,以及存取儲存金鑰的說明。
然後執行下面的儲存格。
import kagglehub
kagglehub.login()
如果一切順利,它應該會顯示 Kaggle credentials set. Kaggle credentials successfully validated.
。
注意:在 Google Colab 中,您也可以在完成上述可選步驟 3 後,使用下面的程式碼驗證 Kaggle。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
現在,載入您想要嘗試的 Gemma 模型。下一個儲存格中的程式碼會利用 kagglehub.model_download
來下載模型檔案。
注意:對於較大的模型,例如 gemma 7b
和 gemma 7b-it
(指示),您可能需要具有大量記憶體的硬體加速器,例如 NVIDIA A100。
from IPython.display import clear_output
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
Python 導入#
from flax import nnx
import sentencepiece as spm
若要與 Gemma 模型互動,您將使用來自 google/flax
GitHub 範例 的 Flax NNX gemma
程式碼。由於它未作為套件公開,您需要使用以下變通方法從 GitHub 上的 Flax NNX examples/gemma
導入。
import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
# Create a temporary directory and clone the `flax` repo.
# Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
! git clone https://github.com/google/flax.git {tmp}/flax
sys.path.append(f"{tmp}/flax/examples/gemma")
import params as params_lib
import sampler as sampler_lib
import transformer as transformer_lib
sys.path.pop();
Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.
載入並準備 Gemma 模型#
首先,載入 Gemma 模型參數以供 Flax 使用。
params = params_lib.load_and_format_params(ckpt_path)
接下來,載入使用 SentencePiece 程式庫建構的 tokenizers 檔案。
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
True
然後,使用 Flax NNX gemma.transformer.TransformerConfig.from_params
函數從檢查點自動載入正確的配置。
注意:由於此版本中未使用 token,詞彙表大小小於輸入嵌入的數量。
transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)
執行取樣/推論#
在您的模型和 tokenizer 上,使用正確的參數形狀建構 Flax NNX gemma.Sampler
。
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
)
您已準備好開始取樣!
注意:此 Flax NNX gemma.Sampler
使用 JAX 的 即時 (JIT) 編譯,因此變更輸入形狀會觸發重新編譯,這可能會減慢速度。為了獲得最快和最有效率的結果,請保持批次大小一致。
在 input_batch
中編寫提示並執行推論。您可以隨意調整 total_generation_steps
(產生回應時執行的步驟數)。
input_batch = [
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]
out_data = sampler(
input_strings=input_batch,
total_generation_steps=300, # The number of steps performed when generating a response.
)
for input_string, out_string in zip(input_batch, out_data.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
print()
print(10*'#')
Prompt:
# Python program for implementation of Bubble Sort
def bubbleSort(arr):
Output:
for i in range(len(arr)):
for j in range(len(arr) - i - 1):
if arr[j] > arr[j + 1]:
swap(arr, j, j + 1)
def swap(arr, i, j):
temp = arr[i]
arr[i] = arr[j]
arr[j] = temp
# Driver code
arr = [5, 2, 8, 3, 1, 9]
print("Unsorted array:")
print(arr)
bubbleSort(arr)
print("Sorted array:")
print(arr)
# Time complexity of Bubble sort O(n^2)
# where n is the length of the array
# Space complexity of Bubble sort O(1)
# as it only requires constant extra space for the swap operation
# This program uses the bubble sort algorithm to sort the given array in ascending order.
```python
# This program uses the bubble sort algorithm to sort the given array in ascending order.
def bubbleSort(arr):
for i in range(len(arr)):
for j in range(len(arr) - i - 1):
if arr[j] > arr[j + 1]:
swap(arr, j, j + 1)
def swap(
##########
您應該會取得氣泡排序演算法的 Python 實作。