专栏名称: 计算机视觉研究院
主要由来自于大学的研究生组成的团队,本平台从事机器学习与深度学习领域,主要在人脸检测与识别,多目标检测研究方向。本团队想通过计算机视觉战队平台打造属于自己的品牌,让更多相关领域的人了解本团队,结识更多相关领域的朋友,一起来学习,共同进步!
目录
相关文章推荐
清廉蓉城  ·  当吒言遇上“纪”语…… ·  2 天前  
成都本地宝  ·  今晚24时!四川油价将调整! ·  4 天前  
成都本地宝  ·  成都能待一整天的6个室内场馆!部分免费! ·  4 天前  
51好读  ›  专栏  ›  计算机视觉研究院

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

计算机视觉研究院  · 公众号  ·  · 2024-07-12 11:05

正文

点击蓝字


关注我们

关注并星标

从此不迷路

计算机视觉研究院


公众号ID 计算机视觉研究院

学习群 扫码在主页获取加入方式

计算机视觉研究院专栏

Column of Computer Vision Institute

很翔实的一篇教 程。
OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。

在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编写整个架构再到生成最终结果的所有内容。

由于作者没有大算力的 GPU,所以仅编写了小规模架构。以下是在不同处理器上训练模型所需时间的比较。


作者表示,在 CPU 上运行显然需要更长的时间来训练模型。如果你需要快速测试代码中的更改并查看结果,CPU 不是最佳选择。因此建议使用 Colab 或 Kaggle 的 T4 GPU 进行更高效、更快速的训练。

构建目标

我们采用了与传统机器学习或深度学习模型类似的方法,即在数据集上进行训练,然后在未见过数据上进行测试。在文本转视频的背景下,假设有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集,然后训练模型来生成猫捡球或狗追老鼠的视频。

图源:iStock, GettyImages

虽然此类训练数据集在互联网上很容易获得,但所需的算力极高。因此,我们将使用由 Python 代码生成的移动对象视频数据集。同时使用 GAN(生成对抗网络)架构来创建模型,而不是 OpenAI Sora 使用的扩散模型。

我们也尝试使用扩散模型,但内存要求超出了自己的能力。另一方面,GAN 可以更容易、更快地进行训练和测试。

准备条件

我们将使用 OOP(面向对象编程),因此必须对它以及神经网络有基本的了解。此外 GAN(生成对抗网络)的知识不是必需的,因为这里简单介绍它们的架构。

  • OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM
  • 神经网络理论:https://www.youtube.com/watch?v=Jy4wM2X21u0
  • GAN 架构:https://www.youtube.com/watch?v=TpMIssRdhco
  • Python 基础:https://www.youtube.com/watch?v=eWRfhZUzrAc

了解 GAN 架构

什么是 GAN?

生成对抗网络是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集创建新数据(如图像或音乐),另一个则判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分。

真实世界应用

  • 生成图像:GAN 根据文本 prompt 创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。
  • 数据增强:GAN 生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。
  • 补充缺失信息:GAN 可以填充缺失数据,例如根据地形图生成地下图像以用于能源应用。
  • 生成 3D 模型:GAN 将 2D 图像转换为 3D 模型,在医疗保健等领域非常有用,可用于为手术规划创建逼真的器官图像。

GAN 工作原理

GAN 由两个深度神经网络组成:生成器和判别器。这两个网络在对抗设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真是假。


GAN 训练示例

让我们以图像到图像的转换为例,解释一下 GAN 模型,重点是修改人脸。

1. 输入图像:输入图像是一张真实的人脸图像。
2. 属性修改:生成器会修改人脸的属性,比如给眼睛加上墨镜。
3. 生成图像:生成器会创建一组添加了太阳镜的图像。
4. 判别器的任务:判别器接收到混合的真实图像(带有太阳镜的人)和生成的图像(添加了太阳镜的人脸)。
5. 评估:判别器尝试区分真实图像和生成图像。
6. 反馈回路:如果判别器正确识别出假图像,生成器会调整其参数以生成更逼真的图像。如果生成器成功欺骗了判别器,判别器会更新其参数以提高检测能力。

通过这一对抗过程,两个网络都在不断改进。生成器越来越善于生成逼真的图像,而判别器则越来越善于识别假图像,直到达到平衡,判别器再也无法区分真实图像和生成的图像。此时,GAN 已成功学会生成逼真的修改图像。

设置背景

我们将使用一系列 Python 库,让我们导入它们。

# Operating System module for interacting with the operating systemimport os
# Module for generating random numbersimport random
# Module for numerical operationsimport numpy as np
# OpenCV library for image processingimport cv2
# Python Imaging Library for image processingfrom PIL import Image, ImageDraw, ImageFont
# PyTorch library for deep learningimport torch
# Dataset class for creating custom datasets in PyTorchfrom torch.utils.data import Dataset
# Module for image transformationsimport torchvision.transforms as transforms
# Neural network module in PyTorchimport torch.nn as nn
# Optimization algorithms in PyTorchimport torch.optim as optim
# Function for padding sequences in PyTorchfrom torch.nn.utils.rnn import pad_sequence
# Function for saving images in PyTorchfrom torchvision.utils import save_image
# Module for plotting graphs and imagesimport matplotlib.pyplot as plt
# Module for displaying rich content in IPython environmentsfrom IPython.display import clear_output, display, HTML
# Module for encoding and decoding binary data to textimport base64
现在我们已经导入了所有的库,下一步就是定义我们的训练数据,用于训练 GAN 架构。

对训练数据进行编码

我们需要至少 10000 个视频作为训练数据。为什么呢?因为我测试了较小数量的视频,结果非常糟糕,几乎没有任何效果。下一个重要问题是:这些视频内容是什么?  我们的训练视频数据集包括一个圆圈以不同方向和不同运动方式移动的视频。让我们来编写代码并生成 10,000 个视频,看看它的效果如何。

# Create a directory named 'training_dataset'os.makedirs('training_dataset', exist_ok=True)
# Define the number of videos to generate for the datasetnum_videos = 10000
# Define the number of frames per video (1 Second Video)frames_per_video = 10
# Define the size of each image in the datasetimg_size = (64, 64)
# Define the size of the shapes (Circle)shape_size = 10

设置一些基本参数后,接下来我们需要定义训练数据集的文本 prompt,并据此生成训练视频。

# Define text prompts and corresponding movements for circlesprompts_and_movements = [ ("circle moving down", "circle", "down"), # Move circle downward ("circle moving left", "circle", "left"), # Move circle leftward ("circle moving right", "circle", "right"), # Move circle rightward ("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right ("circle moving diagonally down-left" , "circle", "diagonal_down_left"), # Move circle diagonally down-left ("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left ("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right ("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise ("circle shrinking", "circle", "shrink"), # Shrink circle ("circle expanding", "circle", "expand"), # Expand circle ("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically ("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally ("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically ("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally ("circle moving up-left", "circle", "up_left"), # Move circle up-left ("circle moving down-right", "circle", "down_right"), # Move circle down-right ("circle moving down-left", "circle", "down_left"), # Move circle down-left]

我们已经利用这些 prompt 定义了圆的几个运动轨迹。现在,我们需要编写一些数学公式,以便根据 prompt 移动圆。

# Define function with parametersdef create_image_with_moving_shape(size, frame_num, shape, direction): # Create a new RGB image with specified size and white background img = Image.new('RGB', size, color=(255, 255, 255))
# Create a drawing context for the image draw = ImageDraw.Draw(img)
# Calculate the center coordinates of the image center_x, center_y = size[0] // 2, size[1] // 2
# Initialize position with center for all movements position = (center_x, center_y)
# Define a dictionary mapping directions to their respective position adjustments or image transformations direction_map = { # Adjust position downwards based on frame number "down": (0, frame_num * 5 % size[1]), # Adjust position to the left based on frame number "left": (-frame_num * 5 % size[0], 0), # Adjust position to the right based on frame number "right": (frame_num * 5 % size[0], 0), # Adjust position diagonally up and to the right "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position diagonally down and to the left "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]), # Adjust position diagonally up and to the left "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position diagonally down and to the right "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), # Rotate the image clockwise based on frame number "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), # Rotate the image counter-clockwise based on frame number "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), # Adjust position for a bouncing effect vertically "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)), # Adjust position for a bouncing effect horizontally "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0), # Adjust position for a zigzag effect vertically "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]), # Adjust position for a zigzag effect horizontally "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y), # Adjust position upwards and to the right based on frame number "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position upwards and to the left based on frame number "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position downwards and to the right based on frame number "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), # Adjust position downwards and to the left based on frame number "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]) }
# Check if direction is in the direction map if direction in direction_map: # Check if the direction maps to a position adjustment if isinstance(direction_map[direction], tuple): # Update position based on the adjustment position = tuple(np.add(position, direction_map[direction])) else: # If the direction maps to an image transformation # Update the image based on the transformation img = direction_map[direction]
# Return the image as a numpy array return np.array(img)

上述函数用于根据所选方向在每一帧中移动我们的圆。我们只需在其上运行一个循环,直至生成所有视频的次数。

# Iterate over the number of videos to generatefor i in range(num_videos): # Randomly choose a prompt and movement from the predefined list prompt, shape, direction = random.choice(prompts_and_movements) # Create a directory for the current video video_dir = f'training_dataset/video_{i}' os.makedirs(video_dir, exist_ok=True) # Write the chosen prompt to a text file in the video directory with open(f'{video_dir}/prompt.txt', 'w') as f: f.write(prompt) # Generate frames for the current video for frame_num in range(frames_per_video): # Create an image with a moving shape based on the current frame number, shape, and direction img = create_image_with_moving_shape(img_size, frame_num, shape, direction) # Save the generated image as a PNG file in the video directory cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

运行上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。


每个训练视频文件夹包含其帧以及对应的文本 prompt。让我们看一下我们的训练数据集样本。

在我们的训练数据集中,我们没有包含圆圈先向上移动然后向右移动的运动。我们将使用这个作为测试 prompt,来评估我们训练的模型在未见过的数据上的表现。


还有一个重要的要点需要注意,我们的训练数据包含许多物体从场景中移出或部分出现在摄像机前方的样本,类似于我们在 OpenAI Sora 演示视频中观察到的情况。


在我们的训练数据中包含此类样本的原因是为了测试当圆圈从角落进入场景时,模型是否能够保持一致性而不会破坏其形状。

现在我们的训练数据已经生成,需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,通过将数据缩放到较小的范围,执行归一化等转换有助于提高训练架构的收敛性和稳定性。

预处理训练数据

我们必须为文本转视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本 prompt,使其可以在 PyTorch 中使用。

# Define a dataset class inheriting from torch.utils.data.Datasetclass TextToVideoDataset(Dataset): def __init__(self, root_dir, transform=None): # Initialize the dataset with root directory and optional transform self.root_dir = root_dir self.transform = transform # List all subdirectories in the root directory self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))] # Initialize lists to store frame paths and corresponding prompts self.frame_paths = [] self.prompts = []
# Loop through each video directory for video_dir in self.video_dirs: # List all PNG files in the video directory and store their paths frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')] self.frame_paths.extend(frames) # Read the prompt text file in the video directory and store its content with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f: prompt = f.read().strip() # Repeat the prompt for each frame in the video and store in prompts list self.prompts.extend([prompt] * len(frames))
# Return the total number of samples in the dataset def __len__(self): return len(self.frame_paths)
# Retrieve a sample from the dataset given an index def __getitem__(self, idx): # Get the path of the frame corresponding to the given index frame_path = self.frame_paths[idx] # Open the image using PIL (Python Imaging Library) image = Image.open(frame_path) # Get the prompt corresponding to the given index prompt = self.prompts[idx]
# Apply transformation if specified if self.transform: image = self.transform(image)
# Return the transformed image and the prompt return image, prompt

在继续编写架构代码之前,我们需要对训练数据进行归一化处理。我们使用 16 的 batch 大小并对数据进行混洗以引入更多随机性。

实现文本嵌入层

你可能已经看到,在 Transformer 架构中,起点是将文本输入转换为嵌入,从而在多头注意力中进行进一步处理。类似地,我们在这里必须编写一个文本嵌入层。基于该层,GAN 架构训练在我们的嵌入数据和图像张量上进行。

# Define a class for text embeddingclass TextEmbedding(nn.Module): # Constructor method with vocab_size and embed_size parameters def __init__(self, vocab_size, embed_size): # Call the superclass constructor super(TextEmbedding, self).__init__() # Initialize embedding layer self.embedding = nn.Embedding(vocab_size, embed_size)
# Define the forward pass method def forward(self, x): # Return embedded representation of input return self.embedding(x)

词汇量将基于我们的训练数据,在稍后进行计算。嵌入大小将为 10。如果使用更大的数据集,你还可以使用 Hugging Face 上已有的嵌入模型。

实现生成器层

现在我们已经知道生成器在 GAN 中的作用,接下来让我们对这一层进行编码,然后了解其内容。

class Generator(nn.Module): def __init__(self, text_embed_size): super(Generator, self).__init__() # Fully connected layer that takes noise and text embedding as input self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8) # Transposed convolutional layers to upsample the input self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1) self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1) self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Output has 3 channels for RGB images # Activation functions self.relu = nn.ReLU(True) # ReLU activation function self.tanh = nn.Tanh() # Tanh activation function for final output
def forward(self, noise, text_embed): # Concatenate noise and text embedding along the channel dimension x = torch.cat((noise, text_embed), dim=1) # Fully connected layer followed by reshaping to 4D tensor x = self.fc1(x).view(-1, 256, 8, 8) # Upsampling through transposed convolution layers with ReLU activation x = self.relu(self.deconv1(x)) x = self.relu(self.deconv2(x)) # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images) x = self.tanh(self.deconv3(x)) return x

该 Generator 类负责根据随机噪声和文本嵌入的组合创建视频帧,旨在根据给定的文本描述生成逼真的视频帧。该网络从完全连接层 (nn.Linear) 开始,将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重新整形并经过一系列的转置卷积层 (nn.ConvTranspose2d),这些层将特征图逐步上采样到所需的视频帧大小。

这些层使用 ReLU 激活 (nn.ReLU) 实现非线性,最后一层使用 Tanh 激活 (nn.Tanh) 将输出缩放到 [-1, 1] 的范围。因此,生成器将抽象的高维输入转换为以视觉方式表示输入文本的连贯视频帧。

实现判别器层

在编写完生成器层之后,我们需要实现另一半,即判别器部分。

class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() # Convolutional layers to process input images self.conv1 = nn.Conv2d(3, 64, 4, 2, 1) # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1 self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1 self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1 # Fully connected layer for classification self.fc1 = nn.Linear(256 * 8 * 8, 1) # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification) # Activation functions self.leaky_relu = nn.LeakyReLU(0.2, inplace=True) # Leaky ReLU activation with negative slope 0.2 self.sigmoid = nn.Sigmoid() # Sigmoid activation for final output (probability)
def forward(self, input): # Pass input through convolutional layers with LeakyReLU activation x = self.leaky_relu(self.conv1(input)) x = self.leaky_relu(self.conv2(x)) x = self.leaky_relu(self.conv3(x)) # Flatten the output of convolutional layers x = x.view(-1, 256 * 8 * 8) # Pass through fully connected layer with Sigmoid activation for binary classification x = self.sigmoid(self.fc1(x)) return x

判别器类用作二元分类器,区分真实视频帧和生成的视频帧。目的是评估视频帧的真实性,从而指导生成器产生更真实的输出。该网络由卷积层 (nn.Conv2d) 组成,这些卷积层从输入视频帧中提取分层特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非线性,同时允许负值的小梯度。

然后,特征图被展平并通过完全连接层 (nn.Linear),最终以 S 形激活 (nn.Sigmoid) 输出指示帧是真实还是假的概率 分数。

通过训练判别器准确地对帧进行分类,生成器同时接受训练以创建更令人信服的视频帧,从而骗过判别器。

编写训练参数

我们必须设置用于训练 GAN 的基础组件,例如损失函数、优化器等。

# Check for GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create a simple vocabulary for text promptsall_prompts = [prompt for prompt, _, _ in prompts_and_movements] # Extract all prompts from prompts_and_movements listvocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))} # Create a vocabulary dictionary where each unique word is assigned an indexvocab_size = len(vocab) # Size of the vocabularyembed_size = 10 # Size of the text embedding vector
def encode_text(prompt): # Encode a given prompt into a tensor of indices using the vocabulary return torch.tensor([vocab[word] for word in prompt.split()])
# Initialize models, loss function, and optimizerstext_embedding = TextEmbedding(vocab_size, embed_size).to(device) # Initialize TextEmbedding model with vocab_size and embed_sizenetG = Generator(embed_size).to(device) # Initialize Generator model with embed_sizenetD = Discriminator().to(device) # Initialize Discriminator modelcriterion = nn.BCELoss().to(device) # Binary Cross Entropy loss functionoptimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999)) # Adam optimizer for DiscriminatoroptimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999)) # Adam optimizer for Generator

这是我们必须转换代码以在 GPU 上运行的部分(如果可用)。我们已经编写了代码来查找 vocab_size,并且我们正在为生成器和判别器使用 ADAM 优化器。你可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这比其他可供公众使用的 Hugging Face 模型要小得多。

编写训练 loop

就像其他神经网络一样,我们将以类似的方式对 GAN 架构训练进行编码。

# Number of epochsnum_epochs = 13
# Iterate over each epochfor epoch in range(num_epochs): # Iterate over each batch of data for i, (data, prompts) in enumerate(dataloader): # Move real data to device real_data = data.to(device) # Convert prompts to list prompts = [prompt for prompt in prompts]
# Update Discriminator netD.zero_grad() # Zero the gradients of the Discriminator batch_size = real_data.size(0) # Get the batch size labels = torch.ones(batch_size, 1).to(device) # Create labels for real data (ones) output = netD(real_data) # Forward pass real data through Discriminator lossD_real = criterion(output, labels) # Calculate loss on real data lossD_real.backward() # Backward pass to calculate gradients # Generate fake data noise = torch.randn(batch_size, 100).to(device) # Generate random noise text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts]) # Encode prompts into text embeddings fake_data = netG(noise, text_embeds) # Generate fake data from noise and text embeddings labels = torch.zeros(batch_size, 1).to(device) # Create labels for fake data (zeros) output = netD(fake_data.detach()) # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator) lossD_fake = criterion(output, labels) # Calculate loss on fake data lossD_fake.backward() # Backward pass to calculate gradients optimizerD.step() # Update Discriminator parameters
# Update Generator netG.zero_grad() # Zero the gradients of the Generator






请到「今天看啥」查看全文