项目简介
大型语言模型 (LLMs) 用于基于检索的任务(尤其是在检索增强生成 (RAG) 中)的缩放面临严重的内存限制,尤其是在微调大量提示序列时。当前的开源库支持跨多个 GPU 的全模型推理和微调,但无法适应检索上下文所需的有效参数分布。为了解决这一差距,我们引入了一种新的框架,用于利用分布式训练对 Llama-2 模型进行 PEFT 兼容微调。我们的框架独特地利用 JAX 的即时 (JIT) 编译和张量分片来实现高效的资源管理,从而在降低内存需求的情况下实现加速微调。这一进步显著提高了复杂 RAG 应用程序微LLMs调的可扩展性和可行性,即使在 GPU 资源有限的系统上也是如此。我们的实验表明,与使用 4 个 GPU 的 Hugging Face/DeepSpeed 实现相比,运行时间提高了 12 倍以上,而每个 GPU 消耗的 VRAM 不到一半。
安装
请确保您安装了最新版本的 jax for GPU。https://github.com/google/jax
若要安装软件包,请在存储库的根目录中运行以下命令:
git clone https://github.com/aniquetahir/JORA.git
cd JORA
pip install -e .
确保 Jax 可以访问 GPU:
import jax
print(jax.devices())
用法
该库可以通过 python 使用,或者提供 gui。
用作库
Parallama 类可用于定义配置。明智的参数设置为默认值。
class ParallamaConfig(NamedTuple):
JAX_PARAMS_PATH: str
LLAMA2_META_PATH: str # e.g. '/tmp/llama2-13B'
MODEL_SIZE: str # '7B', '13B', '70B'
NUM_GPUS: int = None
LORA_R: int = 16
LORA_ALPHA: int = 16
LORA_DROPOUT: float = 0.05
LR: float = 0.0001
BATCH_SIZE: int = 1
N_ACCUMULATION_STEPS: int = 8
MAX_SEQ_LEN = 2000
N_EPOCHS: int = 7
SEED: int = 420
用法示例
基于 Llama-2 的模型
from jora import train_lora, ParallamaConfig, generate_alpaca_dataset
config = ParallamaConfig(MODEL_SIZE=model_size, JAX_PARAMS_PATH=jax_path,
LLAMA2_META_PATH=hf_path)
dataset = generate_alpaca_dataset(dataset_path, 'train', config)
train_lora(config, dataset, checkpoint_path)
基于 Gemma 的模型 Flax Gemma 模型可以从 Kaggle 下载:
import kagglehub
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it', '1.1-2b-it', '1.1-7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
默认情况下,kagglehub 将模型存储在目录中 ~/.cache/kagglehub 。
对于 Gemma 1.1 模型,KaggleHub 将模型存储在以下目录结构中:
1.1-7b-it
├── 1
│ ├── 7b-it
│ └── tokenizer.model
└── 1.complete
因此
config.MODEL_VERSION
,应
7b-it
设置为 for
1.1-7b-it
model。
该
generate_alpaca_dataset
函数用于从 Alpaca 格式的 json 文件生成数据集。这有助于 instruct 格式训练,因为数据集处理、标记化和批处理由库处理。或者,火炬
Dataset
和
DataLoader
可用于自定义数据集。