专栏名称: GitHubStore
分享有意思的开源项目
目录
相关文章推荐
51好读  ›  专栏  ›  GitHubStore

JORA:解决了 LLM 在 RAG 中的内存限制问题

GitHubStore  · 公众号  ·  · 2024-04-16 07:37

正文

项目简介


大型语言模型 (LLMs) 用于基于检索的任务(尤其是在检索增强生成 (RAG) 中)的缩放面临严重的内存限制,尤其是在微调大量提示序列时。当前的开源库支持跨多个 GPU 的全模型推理和微调,但无法适应检索上下文所需的有效参数分布。为了解决这一差距,我们引入了一种新的框架,用于利用分布式训练对 Llama-2 模型进行 PEFT 兼容微调。我们的框架独特地利用 JAX 的即时 (JIT) 编译和张量分片来实现高效的资源管理,从而在降低内存需求的情况下实现加速微调。这一进步显著提高了复杂 RAG 应用程序微LLMs调的可扩展性和可行性,即使在 GPU 资源有限的系统上也是如此。我们的实验表明,与使用 4 个 GPU 的 Hugging Face/DeepSpeed 实现相比,运行时间提高了 12 倍以上,而每个 GPU 消耗的 VRAM 不到一半。

安装

请确保您安装了最新版本的 jax for GPU。https://github.com/google/jax


若要安装软件包,请在存储库的根目录中运行以下命令:

git clone https://github.com/aniquetahir/JORA.gitcd JORApip install -e .

确保 Jax 可以访问 GPU:

import jaxprint(jax.devices())


用法

该库可以通过 python 使用,或者提供 gui。

用作库

Parallama 类可用于定义配置。明智的参数设置为默认值。

class ParallamaConfig(NamedTuple):    JAX_PARAMS_PATH: str    LLAMA2_META_PATH: str # e.g. '/tmp/llama2-13B'    MODEL_SIZE: str # '7B', '13B', '70B'    NUM_GPUS: int = None    LORA_R: int = 16    LORA_ALPHA: int = 16    LORA_DROPOUT: float = 0.05    LR: float = 0.0001    BATCH_SIZE: int = 1    N_ACCUMULATION_STEPS: int = 8    MAX_SEQ_LEN = 2000    N_EPOCHS: int = 7    SEED: int = 420

用法示例

基于 Llama-2 的模型

from jora import train_lora, ParallamaConfig, generate_alpaca_dataset
config = ParallamaConfig(MODEL_SIZE=model_size, JAX_PARAMS_PATH=jax_path, LLAMA2_META_PATH=hf_path)dataset = generate_alpaca_dataset(dataset_path, 'train', config)train_lora(config, dataset, checkpoint_path)

基于 Gemma 的模型 Flax Gemma 模型可以从 Kaggle 下载:

import kagglehubVARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it', '1.1-2b-it', '1.1-7b-it'] {type:"string"}weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')

默认情况下,kagglehub 将模型存储在目录中 ~/.cache/kagglehub 。

对于 Gemma 1.1 模型,KaggleHub 将模型存储在以下目录结构中:

1.1-7b-it├── 1│   ├── 7b-it│   └── tokenizer.model└── 1.complete

因此 config.MODEL_VERSION ,应 7b-it 设置为 for 1.1-7b-it model。

generate_alpaca_dataset 函数用于从 Alpaca 格式的 json 文件生成数据集。这有助于 instruct 格式训练,因为数据集处理、标记化和批处理由库处理。或者,火炬 Dataset DataLoader 可用于自定义数据集。







请到「今天看啥」查看全文