LoRA#

NNX LoRA 類別。

class flax.nnx.LoRA(*args, **kwargs)[原始碼]#

一個獨立的 LoRA 層。

使用範例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
>>> layer.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)
>>> # Wrap around existing layer
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
>>> assert wrapper.base_module == linear
>>> wrapper.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)
>>> y = layer(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
in_features#

輸入特徵的數量。

lora_rank#

LoRA 維度的秩。

out_features#

輸出特徵的數量。

base_module#

一個基礎模組,如果可能的話,可以用來呼叫和替換。

dtype#

計算的資料類型(預設:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的資料類型(預設:float32)。

precision#

計算的數值精度,詳細資訊請參閱 jax.lax.Precision

kernel_init#

權重矩陣的初始化器函數。

lora_param_type#

LoRA 參數的類型。

__call__(x)[原始碼]#

將自身作為函數呼叫。

方法

class flax.nnx.LoRALinear(*args, **kwargs)[原始碼]#

一個 nnx.Linear 層,其中輸出將被 LoRA 化。

模型狀態結構將與 Linear 的結構相容。

使用範例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
>>> linear.kernel.value.shape
(3, 4)
>>> lora_linear.kernel.value.shape
(3, 4)
>>> lora_linear.lora.lora_a.value.shape
(3, 2)
>>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value)
Array(True, dtype=bool)
>>> y = lora_linear(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
in_features#

輸入特徵的數量。

out_features#

輸出特徵的數量。

lora_rank#

LoRA 維度的秩。

base_module#

一個基礎模組,如果可能的話,可以用來呼叫和替換。

dtype#

計算的資料類型(預設:從輸入和參數推斷)。

param_dtype#

傳遞給參數初始化器的資料類型(預設:float32)。

precision#

計算的數值精度,詳細資訊請參閱 jax.lax.Precision

kernel_init#

權重矩陣的初始化器函數。

lora_param_type#

LoRA 參數的類型。

__call__(x)[原始碼]#

沿著最後一個維度對輸入應用線性轉換。

參數

inputs – 要轉換的 nd-array。

回傳值

轉換後的輸入。

方法