专栏名称: GitHubStore
分享有意思的开源项目
目录
相关文章推荐
蓝钻故事  ·  42年前旧作被扒,这一幕震撼上亿人 ·  昨天  
十点读书  ·  30-50岁投资自己最清醒的方式 ·  4 天前  
蓝钻故事  ·  马云过关了 ·  3 天前  
十点读书会  ·  排队3小时,云贵菜正在血洗北上广CBD ·  3 天前  
51好读  ›  专栏  ›  GitHubStore

Open-Sora:开源版的Sora

GitHubStore  · 公众号  ·  · 2024-03-06 08:34

正文

项目简介

本项目希望通过开源社区的力量复现Sora,由北大-兔展AIGC联合实验室共同发起,当前我们资源有限仅搭建了基础架构,无法进行完整训练,希望通过开源社区逐步增加模块并筹集资源进行训练,当前版本离目标差距巨大,仍需持续完善和快速迭代,欢迎Pull request!!!


项目阶段:

  • 基本的

  1. 设置代码库并在景观数据集上训练无条件模型。

  2. 训练可提高分辨率和持续时间的模型。

  • 扩展

  1. 在景观数据集上进行text2video实验。

  2. 在 video2text 数据集上训练 1080p 模型。

  3. 具有更多条件的控制模型。


仓库结构

├── README.md├── docs│   ├── Data.md                    -> Datasets description.│   ├── Contribution_Guidelines.md -> Contribution guidelines description.├── scripts                        -> All training scripts.│   └── train.sh├── sora│   ├── dataset                    -> Dataset code to read videos│   ├── models │   │   ├── captioner               │   │   ├── super_resolution        │   ├── modules│   │   ├── ae                     -> compress videos to latents│   │   │   ├── vqvae│   │   │   ├── vae│   │   ├── diffusion              -> denoise latents│   │   │   ├── dit│   │   │   ├── unet|   ├── utils.py                   │   ├── train.py                   -> Training code


要求和安装

推荐要求如下。

  • Python >= 3.8

  • Pytorch >= 1.13.1

  • CUDA 版本 >= 11.7

  • 安装所需的包:


git clone https://github.com/PKU-YuanGroup/Open-Sora-Plancd Open-Sora-Planconda create -n opensora python=3.8 -yconda activate opensorapip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117pip install -r requirements.txtcd VideoGPTpip install -e .cd ..


用法

数据集

参考Data.md

Video-VQVAE (VideoGPT)


训练

cd src/sora/modules/ae/vqvae/videogpt


请参阅原始存储库。使用 scripts/train_vqvae.py 脚本训练 Video-VQVAE。执行 python scripts/train_vqvae.py -h 以获取有关所有可用训练设置的信息。下面列出了更多相关设置的子集以及默认值。


VQ-VAE 特定设置
  • --embedding_dim :码本嵌入的维数

  • --n_codes 2048 :码本中的代码数量

  • --n_hiddens 240 :残差块中隐藏特征的数量

  • --n_res_layers 4 :剩余块的数量

  • --downsample 4 4 4 :编码器的 T H W 下采样步长


训练设置
  • --gpus 2 :分布式训练的GPU数量

  • --sync_batchnorm :使用 > 1 GPU 时使用 SyncBatchNorm 而不是 BatchNorm3d

  • --gradient_clip_val 1 :训练的梯度裁剪阈值

  • --batch_size 16 :每个 GPU 的批量大小

  • --num_workers 8 :每个 DataLoader 的工作人员数量


数据集设置






请到「今天看啥」查看全文