公众号ID
|
计算机视觉研究院
学习群
|
扫码在主页获取加入方式
Column of Computer Vision Institute
OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。
在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编写整个架构再到生成最终结果的所有内容。
由于作者没有大算力的 GPU,所以仅编写了小规模架构。以下是在不同处理器上训练模型所需时间的比较。
作者表示,在 CPU 上运行显然需要更长的时间来训练模型。如果你需要快速测试代码中的更改并查看结果,CPU 不是最佳选择。因此建议使用 Colab 或 Kaggle 的 T4 GPU 进行更高效、更快速的训练。
我们采用了与传统机器学习或深度学习模型类似的方法,即在数据集上进行训练,然后在未见过数据上进行测试。在文本转视频的背景下,假设有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集,然后训练模型来生成猫捡球或狗追老鼠的视频。
虽然此类训练数据集在互联网上很容易获得,但所需的算力极高。因此,我们将使用由 Python 代码生成的移动对象视频数据集。同时使用 GAN(生成对抗网络)架构来创建模型,而不是 OpenAI Sora 使用的扩散模型。
我们也尝试使用扩散模型,但内存要求超出了自己的能力。另一方面,GAN 可以更容易、更快地进行训练和测试。
我们将使用 OOP(面向对象编程),因此必须对它以及神经网络有基本的了解。此外 GAN(生成对抗网络)的知识不是必需的,因为这里简单介绍它们的架构。
-
OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM
-
神经网络理论:https://www.youtube.com/watch?v=Jy4wM2X21u0
-
GAN 架构:https://www.youtube.com/watch?v=TpMIssRdhco
-
Python 基础:https://www.youtube.com/watch?v=eWRfhZUzrAc
生成对抗网络是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集创建新数据(如图像或音乐),另一个则判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分。
-
生成图像:GAN 根据文本 prompt 创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。
-
数据增强:GAN 生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。
-
补充缺失信息:GAN 可以填充缺失数据,例如根据地形图生成地下图像以用于能源应用。
-
生成 3D 模型:GAN 将 2D 图像转换为 3D 模型,在医疗保健等领域非常有用,可用于为手术规划创建逼真的器官图像。
GAN 由两个深度神经网络组成:生成器和判别器。这两个网络在对抗设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真是假。
让我们以图像到图像的转换为例,解释一下 GAN 模型,重点是修改人脸。
2. 属性修改:生成器会修改人脸的属性,比如给眼睛加上墨镜。
3. 生成图像:生成器会创建一组添加了太阳镜的图像。
4. 判别器的任务:判别器接收到混合的真实图像(带有太阳镜的人)和生成的图像(添加了太阳镜的人脸)。
6. 反馈回路:如果判别器正确识别出假图像,生成器会调整其参数以生成更逼真的图像。如果生成器成功欺骗了判别器,判别器会更新其参数以提高检测能力。
通过这一对抗过程,两个网络都在不断改进。生成器越来越善于生成逼真的图像,而判别器则越来越善于识别假图像,直到达到平衡,判别器再也无法区分真实图像和生成的图像。此时,GAN 已成功学会生成逼真的修改图像。
我们将使用一系列 Python 库,让我们导入它们。
import os
import random
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from IPython.display import clear_output, display, HTML
import base64
现在我们已经导入了所有的库,下一步就是定义我们的训练数据,用于训练 GAN 架构。
我们需要至少 10000 个视频作为训练数据。为什么呢?因为我测试了较小数量的视频,结果非常糟糕,几乎没有任何效果。下一个重要问题是:这些视频内容是什么? 我们的训练视频数据集包括一个圆圈以不同方向和不同运动方式移动的视频。让我们来编写代码并生成 10,000 个视频,看看它的效果如何。
os.makedirs('training_dataset', exist_ok=True)
num_videos = 10000
frames_per_video = 10
img_size = (64, 64)
shape_size = 10
设置一些基本参数后,接下来我们需要定义训练数据集的文本 prompt,并据此生成训练视频。
prompts_and_movements = [
("circle moving down", "circle", "down"),
("circle moving left", "circle", "left"),
("circle moving right", "circle", "right"),
("circle moving diagonally up-right", "circle", "diagonal_up_right"),
("circle moving diagonally down-left"
, "circle", "diagonal_down_left"),
("circle moving diagonally up-left", "circle", "diagonal_up_left"),
("circle moving diagonally down-right", "circle", "diagonal_down_right"),
("circle rotating clockwise", "circle", "rotate_clockwise"),
("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"),
("circle shrinking", "circle", "shrink"),
("circle expanding", "circle", "expand"),
("circle bouncing vertically", "circle", "bounce_vertical"),
("circle bouncing horizontally", "circle", "bounce_horizontal"),
("circle zigzagging vertically", "circle", "zigzag_vertical"),
("circle zigzagging horizontally", "circle", "zigzag_horizontal"),
("circle moving up-left", "circle", "up_left"),
("circle moving down-right", "circle", "down_right"),
("circle moving down-left", "circle", "down_left"),
]
我们已经利用这些 prompt 定义了圆的几个运动轨迹。现在,我们需要编写一些数学公式,以便根据 prompt 移动圆。
def create_image_with_moving_shape(size, frame_num, shape, direction):
img = Image.new('RGB', size, color=(255, 255, 255))
draw = ImageDraw.Draw(img)
center_x, center_y = size[0] // 2, size[1] // 2
position = (center_x, center_y)
direction_map = {
"down": (0, frame_num * 5 % size[1]),
"left": (-frame_num * 5 % size[0], 0),
"right": (frame_num * 5 % size[0], 0),
"diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
"diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),
"diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
"diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
"rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
"rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
"bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),
"bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),
"zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),
"zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),
"up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
"up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
"down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
"down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1])
}
if direction in direction_map:
if isinstance(direction_map[direction], tuple):
position = tuple(np.add(position, direction_map[direction]))
else:
img = direction_map[direction]
return np.array(img)
上述函数用于根据所选方向在每一帧中移动我们的圆。我们只需在其上运行一个循环,直至生成所有视频的次数。
for i in range(num_videos):
prompt, shape, direction = random.choice(prompts_and_movements)
video_dir = f'training_dataset/video_{i}'
os.makedirs(video_dir, exist_ok=True)
with open(f'{video_dir}/prompt.txt', 'w') as f:
f.write(prompt)
for frame_num in range(frames_per_video):
img = create_image_with_moving_shape(img_size, frame_num, shape, direction)
cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)
运行上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。
每个训练视频文件夹包含其帧以及对应的文本 prompt。让我们看一下我们的训练数据集样本。
在我们的训练数据集中,我们没有包含圆圈先向上移动然后向右移动的运动。我们将使用这个作为测试 prompt,来评估我们训练的模型在未见过的数据上的表现。
还有一个重要的要点需要注意,我们的训练数据包含许多物体从场景中移出或部分出现在摄像机前方的样本,类似于我们在 OpenAI Sora 演示视频中观察到的情况。
在我们的训练数据中包含此类样本的原因是为了测试当圆圈从角落进入场景时,模型是否能够保持一致性而不会破坏其形状。
现在我们的训练数据已经生成,需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,通过将数据缩放到较小的范围,执行归一化等转换有助于提高训练架构的收敛性和稳定性。
我们必须为文本转视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本 prompt,使其可以在 PyTorch 中使用。
class TextToVideoDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
self.frame_paths = []
self.prompts = []
for video_dir in self.video_dirs:
frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
self.frame_paths.extend(frames)
with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
prompt = f.read().strip()
self.prompts.extend([prompt] * len(frames))
def __len__(self):
return len(self.frame_paths)
def __getitem__(self, idx):
frame_path = self.frame_paths[idx]
image = Image.open(frame_path)
prompt = self.prompts[idx]
if self.transform:
image = self.transform(image)
return image, prompt
在继续编写架构代码之前,我们需要对训练数据进行归一化处理。我们使用 16 的 batch 大小并对数据进行混洗以引入更多随机性。
你可能已经看到,在 Transformer 架构中,起点是将文本输入转换为嵌入,从而在多头注意力中进行进一步处理。类似地,我们在这里必须编写一个文本嵌入层。基于该层,GAN 架构训练在我们的嵌入数据和图像张量上进行。
class TextEmbedding(nn.Module):
def __init__(self, vocab_size, embed_size):
super(TextEmbedding, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_size)
def forward(self, x):
return self.embedding(x)
词汇量将基于我们的训练数据,在稍后进行计算。嵌入大小将为 10。如果使用更大的数据集,你还可以使用 Hugging Face 上已有的嵌入模型。
现在我们已经知道生成器在 GAN 中的作用,接下来让我们对这一层进行编码,然后了解其内容。
class Generator(nn.Module):
def __init__(self, text_embed_size):
super(Generator, self).__init__()
self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)
self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)
self.relu = nn.ReLU(True)
self.tanh = nn.Tanh()
def forward(self, noise, text_embed):
x = torch.cat((noise, text_embed), dim=1)
x = self.fc1(x).view(-1, 256, 8, 8)
x = self.relu(self.deconv1(x))
x = self.relu(self.deconv2(x))
x = self.tanh(self.deconv3(x))
return x
该 Generator 类负责根据随机噪声和文本嵌入的组合创建视频帧,旨在根据给定的文本描述生成逼真的视频帧。该网络从完全连接层 (nn.Linear) 开始,将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重新整形并经过一系列的转置卷积层 (nn.ConvTranspose2d),这些层将特征图逐步上采样到所需的视频帧大小。
这些层使用 ReLU 激活 (nn.ReLU) 实现非线性,最后一层使用 Tanh 激活 (nn.Tanh) 将输出缩放到 [-1, 1] 的范围。因此,生成器将抽象的高维输入转换为以视觉方式表示输入文本的连贯视频帧。
在编写完生成器层之后,我们需要实现另一半,即判别器部分。
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
self.fc1 = nn.Linear(256 * 8 * 8, 1)
self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, input):
x = self.leaky_relu(self.conv1(input))
x = self.leaky_relu(self.conv2(x))
x = self.leaky_relu(self.conv3(x))
x = x.view(-1, 256 * 8 * 8)
x = self.sigmoid(self.fc1(x))
return x
判别器类用作二元分类器,区分真实视频帧和生成的视频帧。目的是评估视频帧的真实性,从而指导生成器产生更真实的输出。该网络由卷积层 (nn.Conv2d) 组成,这些卷积层从输入视频帧中提取分层特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非线性,同时允许负值的小梯度。
然后,特征图被展平并通过完全连接层 (nn.Linear),最终以 S 形激活 (nn.Sigmoid) 输出指示帧是真实还是假的概率
分数。
通过训练判别器准确地对帧进行分类,生成器同时接受训练以创建更令人信服的视频帧,从而骗过判别器。
我们必须设置用于训练 GAN 的基础组件,例如损失函数、优化器等。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}
vocab_size = len(vocab)
embed_size = 10
def encode_text(prompt):
return torch.tensor([vocab[word] for word in prompt.split()])
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)
netG = Generator(embed_size).to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss().to(device)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
这是我们必须转换代码以在 GPU 上运行的部分(如果可用)。我们已经编写了代码来查找 vocab_size,并且我们正在为生成器和判别器使用 ADAM 优化器。你可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这比其他可供公众使用的 Hugging Face 模型要小得多。
就像其他神经网络一样,我们将以类似的方式对 GAN 架构训练进行编码。
num_epochs = 13
for epoch in range(num_epochs):
for i, (data, prompts) in enumerate(dataloader):
real_data = data.to(device)
prompts = [prompt for prompt in prompts]
netD.zero_grad()
batch_size = real_data.size(0)
labels = torch.ones(batch_size, 1).to(device)
output = netD(real_data)
lossD_real = criterion(output, labels)
lossD_real.backward()
noise = torch.randn(batch_size, 100).to(device)
text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])
fake_data = netG(noise, text_embeds)
labels = torch.zeros(batch_size, 1).to(device)
output = netD(fake_data.detach())
lossD_fake = criterion(output, labels)
lossD_fake.backward()
optimizerD.step()
netG.zero_grad()