LoRA#
NNX LoRA 類別。
- class flax.nnx.LoRA(*args, **kwargs)[原始碼]#
一個獨立的 LoRA 層。
使用範例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0)) >>> layer.lora_a.value.shape (3, 2) >>> layer.lora_b.value.shape (2, 4) >>> # Wrap around existing layer >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1)) >>> assert wrapper.base_module == linear >>> wrapper.lora_a.value.shape (3, 2) >>> layer.lora_b.value.shape (2, 4) >>> y = layer(jnp.ones((16, 3))) >>> y.shape (16, 4)
- in_features#
輸入特徵的數量。
- lora_rank#
LoRA 維度的秩。
- out_features#
輸出特徵的數量。
- base_module#
一個基礎模組,如果可能的話,可以用來呼叫和替換。
- dtype#
計算的資料類型(預設:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設:float32)。
- precision#
計算的數值精度,詳細資訊請參閱 jax.lax.Precision。
- kernel_init#
權重矩陣的初始化器函數。
- lora_param_type#
LoRA 參數的類型。
方法
- class flax.nnx.LoRALinear(*args, **kwargs)[原始碼]#
一個 nnx.Linear 層,其中輸出將被 LoRA 化。
模型狀態結構將與 Linear 的結構相容。
使用範例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0)) >>> linear.kernel.value.shape (3, 4) >>> lora_linear.kernel.value.shape (3, 4) >>> lora_linear.lora.lora_a.value.shape (3, 2) >>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value) Array(True, dtype=bool) >>> y = lora_linear(jnp.ones((16, 3))) >>> y.shape (16, 4)
- in_features#
輸入特徵的數量。
- out_features#
輸出特徵的數量。
- lora_rank#
LoRA 維度的秩。
- base_module#
一個基礎模組,如果可能的話,可以用來呼叫和替換。
- dtype#
計算的資料類型(預設:從輸入和參數推斷)。
- param_dtype#
傳遞給參數初始化器的資料類型(預設:float32)。
- precision#
計算的數值精度,詳細資訊請參閱 jax.lax.Precision。
- kernel_init#
權重矩陣的初始化器函數。
- lora_param_type#
LoRA 參數的類型。
方法