範例:使用預訓練的 Gemma 模型搭配 Flax NNX 進行推論

範例:使用預訓練的 Gemma 模型搭配 Flax NNX 進行推論#

本範例示範如何使用 Flax NNX 載入 Gemma 開放模型檔案,並使用它們執行取樣/推論以產生文字。您將使用以 Flax 和 JAX 編寫的 Flax NNX gemma 模組進行模型參數配置和推論。

Gemma 是一系列輕量級、最先進的開放模型,基於 Google DeepMind 的 Gemini。閱讀更多關於 GemmaGemma 2 的資訊。

建議您使用可存取 A100 GPU 加速的 Google Colab 來執行程式碼。

安裝#

安裝必要的依賴項,包括 kagglehub

! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope

下載模型#

若要使用 Gemma 模型,您需要一個 Kaggle 帳戶和 API 金鑰

  1. 若要建立帳戶,請造訪 Kaggle 並點擊「註冊」。

  2. 如果/一旦您擁有帳戶,您需要登入,前往您的 「設定」,並在「API」下點擊「建立新權杖」以產生並下載您的 Kaggle API 金鑰。

  3. Google Colab 中,在「機密」下新增您的 Kaggle 使用者名稱和 API 金鑰,將使用者名稱儲存為 KAGGLE_USERNAME,金鑰儲存為 KAGGLE_KEY。如果您正在使用 Kaggle Notebook 進行免費 TPU 或其他硬體加速,它在「附加元件」>「機密」下有一個金鑰儲存功能,以及存取儲存金鑰的說明。

然後執行下面的儲存格。

import kagglehub
kagglehub.login()

如果一切順利,它應該會顯示 Kaggle credentials set. Kaggle credentials successfully validated.

注意:在 Google Colab 中,您也可以在完成上述可選步驟 3 後,使用下面的程式碼驗證 Kaggle。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

現在,載入您想要嘗試的 Gemma 模型。下一個儲存格中的程式碼會利用 kagglehub.model_download 來下載模型檔案。

注意:對於較大的模型,例如 gemma 7bgemma 7b-it (指示),您可能需要具有大量記憶體的硬體加速器,例如 NVIDIA A100。

from IPython.display import clear_output

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'

Python 導入#

from flax import nnx
import sentencepiece as spm

若要與 Gemma 模型互動,您將使用來自 google/flax GitHub 範例 的 Flax NNX gemma 程式碼。由於它未作為套件公開,您需要使用以下變通方法從 GitHub 上的 Flax NNX examples/gemma 導入。

import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
  # Create a temporary directory and clone the `flax` repo.
  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
  ! git clone https://github.com/google/flax.git {tmp}/flax
  sys.path.append(f"{tmp}/flax/examples/gemma")
  import params as params_lib
  import sampler as sampler_lib
  import transformer as transformer_lib
  sys.path.pop();
Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.

載入並準備 Gemma 模型#

首先,載入 Gemma 模型參數以供 Flax 使用。

params = params_lib.load_and_format_params(ckpt_path)

接下來,載入使用 SentencePiece 程式庫建構的 tokenizers 檔案。

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
True

然後,使用 Flax NNX gemma.transformer.TransformerConfig.from_params 函數從檢查點自動載入正確的配置。

注意:由於此版本中未使用 token,詞彙表大小小於輸入嵌入的數量。

transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)

執行取樣/推論#

在您的模型和 tokenizer 上,使用正確的參數形狀建構 Flax NNX gemma.Sampler

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)

您已準備好開始取樣!

注意:此 Flax NNX gemma.Sampler 使用 JAX 的 即時 (JIT) 編譯,因此變更輸入形狀會觸發重新編譯,這可能會減慢速度。為了獲得最快和最有效率的結果,請保持批次大小一致。

input_batch 中編寫提示並執行推論。您可以隨意調整 total_generation_steps (產生回應時執行的步驟數)。

input_batch = [
    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
  ]

out_data = sampler(
    input_strings=input_batch,
    total_generation_steps=300,  # The number of steps performed when generating a response.
  )

for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
  print()
  print(10*'#')
Prompt:

# Python program for implementation of Bubble Sort

def bubbleSort(arr):
Output:

    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(arr, i, j):
    temp = arr[i]
    arr[i] = arr[j]
    arr[j] = temp


# Driver code
arr = [5, 2, 8, 3, 1, 9]
print("Unsorted array:")
print(arr)
bubbleSort(arr)
print("Sorted array:")
print(arr)


# Time complexity of Bubble sort O(n^2)
# where n is the length of the array


# Space complexity of Bubble sort O(1)
# as it only requires constant extra space for the swap operation


# This program uses the bubble sort algorithm to sort the given array in ascending order.

```python
# This program uses the bubble sort algorithm to sort the given array in ascending order.

def bubbleSort(arr):
    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(

##########

您應該會取得氣泡排序演算法的 Python 實作。