为了账号安全,请及时绑定邮箱和手机立即绑定

用Python构建将文本转换为视频的AI模型的全过程

从零开始制作AI生成的视频

紧随着大型语言模型(LLMs),这些文本到视频模型已经成为2024年最热门的人工智能趋势之一,包括来自OpenAI的Sora和来自Stability AI的Stable Video Diffusion。在这篇博客中,我们将从零开始构建一个小规模文本到视频模型。我们将输入一个文本提示,我们的模型将根据该提示生成一个视频。这篇博客将涵盖从理解理论概念到编码整个架构,最终生成视频的所有内容。

我没有高端的GPU,所以我编写了一个小规模的架构。以下是不同处理器上训练此模型所需时间的比较。

在 CPU 上运行显然需要更多的时间来训练模型。如果你想要快速测试代码变更并查看结果,CPU 并不是最佳选择。我建议使用来自 ColabKaggle 的 T4 GPU 来进行更高效且更快的训练。

为了避免从这篇博客复制粘贴代码,这里有一个包含所有代码和相关信息的 Jupyter Notebook 的 GitHub 仓库链接。

GitHub - FareedKhan-dev/AI-text-to-video-model-from-scratch: 在博客里,我们将从头开始做一个小规模的文本转视频模型。我们将输入一段文本提示,…github.com

这里是一个博客链接,它会指导你一步一步地从零开始创建Stable Diffusion:

从零开始实现稳定扩散:使用Python实现扩散模型的逐步指南.levelup.gitconnected.com
目录(内容提要)
  1. 我们正在构建什么?
  2. 前置条件
  3. 了解GAN架构
    ∘ 什么是GAN?
    ∘ 实际应用案例
    ∘ GAN的工作原理是什么?
    ∘ GAN训练示例
  4. 准备阶段
  5. 准备训练数据
  6. 预处理训练资料
  7. 实现文本嵌入层
  8. 实现生成器层
  9. 实现判别器层
  10. 设置训练参数
  11. 编写训练循环
  12. 保存训练模型
  13. 生成AI视频片段
  14. 还有哪些不足?
  15. 关于作者
我们在建什么

我们将采用类似于传统机器学习或深度学习模型的方法,这些模型在一个数据集上进行训练,然后在未见过的数据上进行测试。在文本到视频的场景中,假设我们有一个包含10万段狗狗捡球和猫咪捉老鼠视频的数据集。我们将训练模型来生成猫咪捡球或狗狗捉老鼠的视频。

这些视频来自 iStock 网站和 GettyImages 网站。

虽然这样的训练数据集在网上很容易找到,但需要的计算资源非常庞大。因此,我们将使用一个由Python代码生成的移动物体视频数据集进行工作。

我们将使用对抗生成网络(GAN架构)来创建我们的模型,而不是OpenAI Sora使用的扩散模型架构。我尝试使用扩散模型架构,但它因为需要太多内存而崩溃,超出了我的承受能力。相比之下,GAN更容易且更快地进行训练和测试。

前提条件:

我们将使用面向对象编程(OOP),您需要对面向对象编程(OOP)和神经网络有一定的基础理解。了解生成对抗网络(GANs)不是强制性的,因为我们将在这里介绍其架构。

理解GAN架构

了解GAN很重要,因为很多架构都依赖于它。我们来了解一下GAN是什么,它的组成部分,以及更多细节。

什么是生成对抗网络(即Generative Adversarial Network)?

生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习模型,其中两个神经网络对抗:一个网络基于给定的数据集生成新的数据样本,而另一个网络试图分辨这些数据是真实的还是伪造的。这个对抗过程会持续进行,直到生成的数据和原始数据无法区分。

实际应用案例
  1. 生成图像:GANs可以从文本提示生成逼真的图像,或者修改现有图像,例如提高分辨率或给黑白照片上色。
  2. 数据增强:它们生成合成数据来增强其他机器学习模型的训练,例如为欺诈检测系统生成欺诈交易数据。
  3. 补齐缺失信息:GANs可以填补缺失的数据,例如通过地形图生成地下图像,用于能源领域的应用。
  4. 生成3D模型:它们将2D图像转换为3D模型,在医疗等领域非常有用,可以为手术计划生成逼真的器官图像。
GAN(生成对抗网络)是怎么工作的?

它由两个深度神经网络组成:生成器判别器。这两个网络在一个对抗的环境中一起训练,其中一个生成新的数据,另一个判断数据是否真实。

这是GAN工作方式的一个简单解释:

  1. 训练集分析:生成器分析训练集以识别数据的属性,而判别器独立分析相同的数据以学习其属性。
  2. 数据修改:生成器向数据的某些属性添加噪音(随机变化)。
  3. 数据传递:修改后的数据随后传递给判别器。
  4. 概率计算:判别器计算生成的数据属于原始数据集的可能性。
  5. 反馈循环:判别器向生成器提供反馈,引导生成器在下一次循环中减少随机噪声。
  6. 对抗训练:生成器试图最大化判别器的错误,而判别器试图最小化自己的错误。经过许多训练迭代,两个网络不断改进和进化。
  7. 平衡状态:训练继续进行,直到判别器无法区分真实数据与合成数据,表明生成器已成功学会生成逼真的数据了。此时,训练过程完成。

来自AWS Guide

GAN示例

让我们通过一个图像到图像的转换例子来具体说明GAN模型,重点是修改人脸的照片。

  1. 输入图像:输入的是一张真实的人脸照片。
  2. 属性修改:生成器修改人脸属性,比如在眼睛位置添加太阳镜。
  3. 生成图像:生成器会生成一些带太阳镜的人脸图像。
  4. 判别器任务:判别器收到一些戴太阳镜的人的真实照片和添加了太阳镜的人脸图像。
  5. 评估:判别器试图分辨哪些是真实照片,哪些是生成的图像。
  6. 反馈回路:如果判别器能准确识别出假图像,生成器会调整参数,使之能生成更逼真的图像。如果生成器成功骗过判别器,判别器会调整参数,提高识别能力。

通过这个对抗的过程,两个网络不断改进。生成器越来越擅长创造逼真的图像,而判别器也越来越擅长分辨假图像,直到达到一种平衡。此时,判别器再也无法分辨出真实图像和生成的图像。此时,生成对抗网络(GAN)已成功学会生成逼真的图像。

设定

我们将用到一系列Python库,现在来导入它们。

    # 操作系统模块,用于与操作系统交互  
    import os  

    # 生成随机数的库  
    import random  

    # 数值运算库  
    import numpy as np  

    # 用于图像处理的OpenCV库  
    import cv2  

    # PIL图像处理库  
    from PIL import Image, ImageDraw, ImageFont  

    # 用于深度学习的PyTorch库  
    import torch  

    # PyTorch中的自定义数据集类  
    from torch.utils.data import Dataset  

    # 图像变换模块  
    import torchvision.transforms as transforms  

    # PyTorch神经网络模块  
    import torch.nn as nn  

    # PyTorch优化算法模块  
    import torch.optim as optim  

    # PyTorch填充序列函数  
    from torch.nn.utils.rnn import pad_sequence  

    # PyTorch保存图像函数  
    from torchvision.utils import save_image  

    # 绘制图形和图像的模块  
    import matplotlib.pyplot as plt  

    # IPython环境中的显示丰富内容模块  
    from IPython.display import clear_output, display, HTML  

    # 二进制数据编码解码模块  
    import base64

现在我们已经导入了所有的库,接下来就是定义用来训练GAN架构的训练数据了。

编写训练数据代码

我们需要至少10,000个视频来进行训练。因为我用较少数量进行测试,结果非常差,几乎看不到任何效果。接下来最重要的问题是:这些视频是关于什么的?我们的训练视频数据集包括一个圆在不同方向和以不同方式移动。所以,让我们写代码来生成10,000个这样的视频看看效果如何。

    # 创建名为 'training_dataset' 的文件夹  
    os.makedirs('training_dataset', exist_ok=True)  

    # 定义数据集生成的视频数量  
    num_videos = 10000  

    # 定义每段视频为1秒的帧数  
    frames_per_video = 10  

    # 定义数据集中每个图像的尺寸  
    img_size = (64, 64)  

    # 定义圆形的大小  
    shape_size = 10

设置了一些基本参数之后,我们需要根据训练数据集定义训练文本提示内容,训练视频也将根据这些提示生成。

    # 定义文本提示及其对应的圆圈运动  
    prompts_and_movements = [  
        ("圆圈向下移动", "circle", "down"),  
        ("圆圈向左移动", "circle", "left"),  
        ("圆圈向右移动", "circle", "right"),  
        ("圆圈沿右上对角线移动", "circle", "diagonal_up_right"),  
        ("圆圈沿左下对角线移动", "circle", "diagonal_down_left"),  
        ("圆圈沿左上对角线移动", "circle", "diagonal_up_left"),  
        ("圆圈沿右下对角线移动", "circle", "diagonal_down_right"),  
        ("圆圈顺时针旋转", "circle", "rotate_clockwise"),  
        ("圆圈逆时针旋转", "circle", "rotate_counter_clockwise"),  
        ("圆圈缩小", "circle", "shrink"),  
        ("圆圈放大", "circle", "expand"),  
        ("圆圈垂直弹跳", "circle", "bounce_vertical"),  
        ("圆圈水平弹跳", "circle", "bounce_horizontal"),  
        ("圆圈沿垂直之字形移动", "circle", "zigzag_vertical"),  
        ("圆圈沿水平之字形移动", "circle", "zigzag_horizontal"),  
        ("圆圈向左上方移动", "circle", "up_left"),  
        ("圆圈向右下方移动", "circle", "down_right"),  
        ("圆圈向左下方移动", "circle", "down_left"),  
    ]

我们已经定义了几种运动方式。我们现在需要编写一些数学公式来根据这些提示来移动这个圆。

    # 定义带有参数的函数  
    def create_image_with_moving_shape(size, frame_num, shape, direction):  

        # 创建指定大小的 RGB 图像,背景为白色填充  
        img = Image.new('RGB', size, color=(255, 255, 255))    

        # 为图像创建绘图上下文对象  
        draw = ImageDraw.Draw(img)    

        # 计算图像的中心点坐标  
        center_x, center_y = size[0] // 2, size[1] // 2    

        # 使用中心点作为所有运动的起始位置  
        position = (center_x, center_y)    

        # 定义一个字典,将方向映射到相应的位置调整或图像变换  
        direction_map = {    
            # 向下调整位置  
            "down": (0, frame_num * 5 % size[1]),    
            # 向左调整位置  
            "left": (-frame_num * 5 % size[0], 0),    
            # 向右调整位置  
            "right": (frame_num * 5 % size[0], 0),    
            # 对角线上调至右边  
            "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),    
            # 对角线下调至左边  
            "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),    
            # 对角线上调至左边  
            "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),    
            # 对角线下调至右边  
            "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),    
            # 顺时针旋转图像  
            "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),    
            # 逆时针旋转图像  
            "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),    
            # 垂直跳跃效果位置调整  
            "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),    
            # 水平跳跃效果位置调整  
            "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),    
            # 垂直锯齿效果位置调整  
            "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),    
            # 水平锯齿效果位置调整  
            "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),    
            # 向上并向右移动位置  
            "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),    
            # 向上并向左移动位置  
            "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),    
            # 向下并向右移动位置  
            "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),    
            # 向下并向左移动位置  
            "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1])    
        }  

        # 检查给定的方向是否在方向映射中  
        if direction in direction_map:    
            # 检查方向是否映射到位置调整  
            if isinstance(direction_map[direction], tuple):    
                # 根据调整更新位置坐标  
                position = tuple(np.add(position, direction_map[direction]))    
            else:  # 如果方向映射到图像变换  
                # 根据变换更新图像对象  
                img = direction_map[direction]    

        # 将图像转换为 numpy 数组并返回  
        return np.array(img)

上面这个函数用于根据选定的方向,每帧移动我们的圆圈。具体来说,我们只需要在其上运行一个循环,循环次数为视频的数量,从而生成所有视频。

# 生成视频的数量  
for i in range(num_videos):  
    # 从预定义的列表中随机选择一个提示和运动  
    prompt, shape, direction = random.choice(prompts_and_movements)  

    # 为当前视频创建一个目录  
    video_dir = f'training_dataset/video_{i}'  
    os.makedirs(video_dir, exist_ok=True)  

    # 将选定的提示写入到视频目录下的文本文件中  
    with open(f'{video_dir}/prompt.txt', 'w') as f:  
        f.write(prompt)  

    # 为当前视频生成帧  
    for frame_num in range(frames_per_video):  
        # 根据当前帧数、形状和方向生成一个包含移动形状的图像  
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)  

        # 将生成的图像保存为PNG格式的文件到视频目录下  
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

运行这段代码后,它将生成我们的整个训练数据集。以下是训练数据集文件的结构,是这样的。

每个训练视频文件夹不仅包含视频帧,还包括相应的文本提示。我们来看看数据集的一个样本。

在我们的训练数据集中,我们没有包含圆圈先向上再向右移动的过程。我们将用这个作为测试提示,来评估我们的模型在新数据上的表现。

另一个重要点是,我们的训练数据中包含了大量样本,这些样本中的对象要么逐渐远离摄像机,要么部分遮挡摄像机,这与我们在OpenAI Sora演示视频中看到的情形相似。

我们将这些样本包含在训练数据中的原因是测试我们的模型是否在圆形从画面一角进入时能够保持其完整性且不破坏其形状。

现在我们有了生成的训练数据,我们需要将训练视频转换成张量(tensor),这是像PyTorch这样的深度学习框架中主要的数据类型。此外,进行归一化等转换有助于通过将数据缩放到较小的范围内,提高训练架构的收敛性和稳定性。

处理我们的训练数据

我们需要创建一个用于文本转视频的任务的数据集类,可以从训练数据集文件夹中读取视频帧及其相应的文本提示,使其能够在PyTorch中使用。

    # 定义一个继承自 torch.utils.data.Dataset 的数据集类 TextToVideoDataset  
    class TextToVideoDataset(Dataset):  
        def __init__(self, root_dir, transform=None):  
            # 使用根目录和可选的转换方法初始化  
            self.root_dir = root_dir  
            self.transform = transform  
            # 列出根目录下的所有子目录,并将它们存储在 self.video_dirs 中  
            self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]  
            # 初始化列表以存储帧路径和提示  
            self.frame_paths = []  
            self.prompts = []  

            # 遍历每个视频目录  
            for video_dir in self.video_dirs:  
                # 列出视频目录下的所有 PNG 文件,存储所有 PNG 文件的路径  
                frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]  
                self.frame_paths.extend(frames)  
                # 读取视频目录下的 prompt.txt 文件,并将其内容存储在提示列表中  
                with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:  
                    prompt = f.read().strip()  
                # 对于视频中的每个帧,重复提示并存储在提示列表中  
                self.prompts.extend([prompt] * len(frames))  

        # 返回数据集中的样本数目  
        def __len__(self):  
            return len(self.frame_paths)  

        # 根据索引从数据集中获取样本  
        def __getitem__(self, idx):  
            # 获取对应索引的帧路径  
            frame_path = self.frame_paths[idx]  
            # 使用 PIL(Python Imaging 库)打开图像  
            image = Image.open(frame_path)  
            # 获取对应索引的提示  
            prompt = self.prompts[idx]  

            # 如果指定了转换方法,则应用该转换  
            if self.transform:  
                image = self.transform(image)  

            # 返回转换后的图像和对应的提示  
            return image, prompt

在开始编写架构代码之前,我们需要对训练数据进行规范化。我们将使用批大小为16,并将数据打乱以增加更多的随机性。

    # 定义要对数据执行的一组转换
    transform = transforms.Compose([
        transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为张量
        transforms.Normalize((0.5,), (0.5,)) # 使用均值和标准差0.5对图像进行标准化
    ])

    # 使用上述定义的转换加载数据集
    dataset = TextToVideoDataset(root_dir='training_dataset', transform=transform)
    # 创建一个数据加载器来加载数据集
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
实现文本嵌入模块的功能

你可能在变压器架构中见过,第一步是将我们的文本输入转换为嵌入向量,以便在多头注意力机制中进一步处理这些嵌入向量。与此类似,在这里我们需要实现一个文本嵌入层,基于这个嵌入层,我们将使用生成对抗网络对嵌入数据和图像张量进行训练。

    # 定义一个文本嵌入的类  
    class TextEmbedding(nn.Module):  
        # 一个带有 vocab_size 和 embed_size 参数的构造方法  
        def __init__(self, vocab_size, embed_size):  
            # 调用父类构造方法  
            super(TextEmbedding, self).__init__()  
            # 初始化嵌入层,用于将词汇表中的索引转换为嵌入向量  
            self.embedding = nn.Embedding(vocab_size, embed_size)  

        # 定义前向方法  
        def forward(self, x):  
            # 返回输入的嵌入  
            return self.embedding(x)

词汇表的大小将根据我们的训练数据确定,我们将在后面进行计算。嵌入维度为10。如果您使用更大的数据集,您还可以使用Hugging Face上提供的各种嵌入模型。

创建生成层

既然我们已经知道了生成器在GAN中的作用,让我们写这个层的代码,然后理解内容。

    class Generator(nn.Module):  
        def __init__(self, text_embed_size):  
            super(Generator, self).__init__()  

            # 噪声和文本嵌入作为输入的全连接层  
            self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)  

            # 通过转置卷积层实现的上采样层  
            self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)  
            self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)  
            self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)  # 输出有3个通道,用于RGB图像的色彩表示  

            # 激活单元  
            self.relu = nn.ReLU(True)  # ReLU激活单元  
            self.tanh = nn.Tanh()       # 最终输出的Tanh激活单元,确保输出值在-1和1之间(用于图像)  

        def forward(self, noise, text_embed):  
            # 沿通道维度拼接噪声和文本嵌入  
            x = torch.cat((noise, text_embed), dim=1)  

            # 经过全连接层后,重塑为4维张量  
            x = self.fc1(x).view(-1, 256, 8, 8)  

            # 通过带有ReLU激活的转置卷积层进行上采样  
            x = self.relu(self.deconv1(x))  
            x = self.relu(self.deconv2(x))  

            # 最终层使用Tanh激活单元,确保输出值在-1和1之间  
            x = self.tanh(self.deconv3(x))  

            return x

这个 Generator 类的作用是负责从随机噪声和文本嵌入的组合中生成视频帧。它的目标是根据给定的文本描述生成逼真的视频帧。网络开始于一个全连接层(nn.Linear),其中将噪声向量与文本嵌入结合成一个特征向量。然后,这个特征向量经过重塑并通过一系列转置卷积层(nn.ConvTranspose2d),逐步将特征图上采样至所需的视频帧大小。

各层使用ReLU激活函数(nn.ReLU)来引入非线性,而最后一层使用Tanh激活函数(nn.Tanh)将输出范围限制在[-1, 1]。因此,生成器将抽象的高维输入转换为以视觉形式展现输入文本的连贯的视频帧。

搭建辨别器层

完成生成器层的编码之后,接下来我们要实现的就是判别器部分。

    class Discriminator(nn.Module):  
        def __init__(self):  
            super(Discriminator, self).__init__()  

            # 用于处理输入图像的卷积层  
            self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3个输入通道(RGB),64个输出通道,4x4的卷积核,步长2,填充1  
            self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64个输入通道,128个输出通道,4x4的卷积核,步长2,填充1  
            self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128个输入通道,256个输出通道,4x4的卷积核,步长2,填充1  

            # 全连接层进行分类  
            self.fc1 = nn.Linear(256 * 8 * 8, 1)  # 输入尺寸256x8x8(最后一个卷积层的输出尺寸),输出大小1(二分类)  

            # 激活函数  
            self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # 负斜率为0.2的Leaky ReLU激活函数  
            self.sigmoid = nn.Sigmoid()  # 用于计算最终输出概率的Sigmoid激活函数  

        def forward(self, input):  
            # 输入经过带有LeakyReLU激活函数的卷积层  
            x = self.leaky_relu(self.conv1(input))  
            x = self.leaky_relu(self.conv2(x))  
            x = self.leaky_relu(self.conv3(x))  

            # 将卷积层的输出展平  
            x = x.view(-1, 256 * 8 * 8)  

            # 通过全连接层并使用Sigmoid激活函数进行二分类  
            x = self.sigmoid(self.fc1(x))  

            return x

鉴别器类作为二分类器,区分真实与生成的视频帧。其目的是评估视频帧的真实性,从而指导生成器生成更逼真的视频。该网络由nn.Conv2d卷积层组成,用于从输入视频帧中抽取分层特征,并采用Leaky ReLU激活函数(nn.LeakyReLU)添加非线性,同时允许负值的小梯度。然后将特征图展平并通过一个nn.Linear全连接层,最终通过nn.Sigmoid激活函数输出一个概率分数。

通过训练判别器准确区分帧,生成器也被训练生成更逼真的视频画面,它试图欺骗判别器。

编码训练参数设置

我们必须为训练GAN建立基础部分,比如损失函数、优化器等。

    # 检查GPU  
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  

    # 创建一个简单的词汇表用于文本提示词  
    all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # 从prompts_and_movements列表中提取所有提示词  
    vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # 创建一个用于文本提示词的词汇表,其中每个唯一的单词都被分配一个索引  
    vocab_size = len(vocab)  # 词汇表大小  
    embed_size = 10  # 文本嵌入向量大小  

    def encode_text(prompt):  
        # 使用词汇表将给定的提示词编码为索引张量  
        return torch.tensor([vocab[word] for word in prompt.split()])  

    # 初始化模型、损失函数和优化器  
    text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # 使用vocab_size和embed_size初始化TextEmbedding模型  
    netG = Generator(embed_size).to(device)  # 使用embed_size初始化生成器模型  
    netD = Discriminator().to(device)  # 初始化判别器模型  
    criterion = nn.BCELoss().to(device)  # 二元交叉熵损失函数  
    optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 用于判别器的Adam优化器  
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 用于生成器的Adam优化器

这是我们将代码调整为能够在可用的GPU上运行的部分。我们已经编写了代码来确定vocab_size,并为生成器和判别器都使用了ADAM优化器。如果您愿意,可以选择您自己的优化器。我们将学习率设置为0.0002,嵌入大小为10,这比大多数公开可用的Hugging Face模型要小得多。

编码训练循环

就像其他的神经网络一样,我们也将以类似的方式训练GAN。

    # 训练周期数  
    num_epochs = 13  

    # 每个训练周期进行迭代  
    for epoch in range(num_epochs):  
        # 遍历每个数据批  
        for i, (data, prompts) in enumerate(dataloader):  
            # 将真实数据传输到设备上  
            real_data = data.to(device)  

            # 将提示转换为列表形式  
            prompts = [prompt for prompt in prompts]  

            # 更新辨别器(判别网络)  
            netD.zero_grad()  # 将辨别器(判别网络)的梯度清零  
            batch_size = real_data.size(0)  # 获取批次大小  
            labels = torch.ones(batch_size, 1).to(device)  # 为真实数据创建标签(值为1)  
            output = netD(real_data)  # 将真实数据输入判别器  
            lossD_real = criterion(output, labels)  # 计算真实数据的损失值  
            lossD_real.backward()  # 进行反向传播以计算梯度  

            # 生成随机噪声向量  
            noise = torch.randn(batch_size, 100).to(device)  # 生成随机噪声向量  
            text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # 将提示编码为文本嵌入向量  
            fake_data = netG(noise, text_embeds)  # 使用噪声向量和文本嵌入向量生成假数据  
            labels = torch.zeros(batch_size, 1).to(device)  # 为假数据创建标签(值为0)  
            output = netD(fake_data.detach())  # 将假数据送入判别器(断开以避免生成器梯度流动)  
            lossD_fake = criterion(output, labels)  # 计算假数据的损失值  
            lossD_fake.backward()  # 进行反向传播以计算梯度  
            optimizerD.step()  # 更新辨别器(判别网络)的参数  

            # 更新生成器的参数  
            netG.zero_grad()  # 将生成器的梯度清零  
            labels = torch.ones(batch_size, 1).to(device)  # 为假数据创建标签(值为1)以欺骗辨别器(判别网络)  
            output = netD(fake_data)  # 将假数据(现在更新)输入判别器  
            lossG = criterion(output, labels)  # 根据判别器的输出计算生成器的损失  
            lossG.backward()  # 进行反向传播以计算梯度  
            optimizerG.step()  # 更新生成器的参数  

        # 打印每个周期的损失信息  
        print(f"Epoch [{epoch + 1}/{num_epochs}] 判别器损失: {lossD_real + lossD_fake}, 生成器损失: {lossG}")

通过反向传播,我们的损失函数将会同时调整生成器和判别器。我们为训练周期使用了13个周期。我测试了不同的值,但结果表明如果周期超过这个数目,结果差异不大。此外,存在较高的过拟合可能性。如果我们有更多样化的数据集,包含更多的动作和形状,可以考虑使用更多的周期,但在这种情况下,我们不建议这样做。

运行这段代码,它就开始训练起来,每个 epoch 结束时输出生成器和判别器的损失值。

    ## 输出 ##

    轮次 [1/13] D 损失: 0.8798642754554749, G 损失: 1.300612449645996  
    轮次 [2/13] D 损失: 0.8235711455345154, G 损失: 1.3729925155639648  
    轮次 [3/13] D 损失: 0.6098687052726746, G 损失: 1.3266581296920776  

    等
保存训练模型

训练结束后,我们需要保存训练得到的GAN模型的判别器和生成器,这只需要两行代码就能搞定。

    # 将生成器模型的状态字典保存为名为 'generator.pth' 的文件  
    torch.save(netG.state_dict(), 'generator.pth')  

    # 将判别器模型的状态字典保存为名为 'discriminator.pth' 的文件  
    torch.save(netD.state_dict(), 'discriminator.pth')
制作AI视频

正如我们讨论的,我们在未见数据上测试模型的方法类似于这样的例子:我们的训练数据涉及狗捡取球和猫追老鼠。因此,我们的测试场景可以涉及类似的情况,例如猫追球或狗追老鼠。

在我们的特定情况下,这种运动模式不在我们的训练数据中出现,所以模型对这种特定运动感到陌生。然而,它已经接受过其他运动的训练。我们可以用这种运动作为提示来测试我们的模型,并看看它如何表现。

    # 推理函数,根据给定的文本提示生成视频  
    def generate_video(text_prompt, num_frames=10):  
        # 根据文本提示创建用于存储视频帧的目录  
        os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)  

        # 将文本提示编码为文本嵌入向量  
        text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)  

        # 生成每一帧视频  
        for frame_num in range(num_frames):  
            # 生成随机噪声向量  
            noise = torch.randn(1, 100).to(device)  

            # 使用生成器网络生成假帧图像  
            with torch.no_grad():  
                fake_frame = netG(noise, text_embed)  

            # 保存生成的假帧图像为文件  
            save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')  

    # 使用generate_video函数并指定特定文本提示  
    generate_video('circle moving up-right')

当我们运行这段代码时,它将生成一个包含所有视频帧的目录。我们需要编写一些代码将所有这些帧合成为一个短的短视频。

    # 定义存放PNG帧的文件夹路径  
    folder_path = 'generated_video_circle_moving_up-right'  

    # 获取文件夹中的所有PNG文件的列表  
    image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]  

    # 按文件名排序图像(假设它们是按顺序命名的)  
    image_files.sort()  

    # 创建一个列表来存储帧  
    frames = []  

    # 读取每个图像并将图像追加到帧列表中  
    for image_file in image_files:  
      image_path = os.path.join(folder_path, image_file)  
      frame = cv2.imread(image_path)  
      frames.append(frame)  

    # 将帧列表转换为numpy数组,以便于处理  
    frames = np.array(frames)  

    # 定义帧率(每秒帧数,fps)  
    fps = 10  

    # 创建一个视频写入器  
    fourcc = cv2.VideoWriter_fourcc(*'XVID')  
    out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))  

    # 将每个帧写入视频  
    for frame in frames:  
      out.write(frame)  

    # 释放视频写入器  
    out.release()

确保文件夹路径指向你新生成的视频所在的文件夹。运行这段代码后,你的AI视频将会创建成功。让我们看看它是怎样的。

我多次使用相同数量的 epochs 进行了训练。在这两次尝试中,圆圈都是从底部开始出现一半。好的一面是,在这两种情况下,我们的模型都尝试执行向上和向右的动作。比如说,在第一次尝试中,圆圈先是对角线上移,然后继续向上移动,而在第二次尝试中,圆圈在对角线移动的同时缩小了大小。圆圈都没有向左移动或完全消失,这表明我们的模型表现不错。

缺什么?

我测试了这种架构的各种方面,发现训练数据最关键。在数据集中加入更多的动作和形状,可以增加多样性和提高模型性能。因为数据是通过代码生成的,生成更多样化的数据不费时间;相反,你可以更专注于优化逻辑。

此外,这篇博客中讨论的GAN架构相对简单。你可以通过采用更高级的技术或使用语言模型(LLM)嵌入而不是简单的神经网络嵌入来使其更加复杂。此外,调整嵌入大小等参数可以显著影响模型的效果。

关于我

我拥有数据科学的硕士学位(MSc),在自然语言处理和人工智能领域工作已经超过两年了。你可以雇用我,或者问我任何关于人工智能的问题!不管是什么问题,我都会回复你的邮件。

跟我联系:https://www.linkedin.com/in/fareed-khan-dev/,跟我在领英上连接。

请通过以下邮箱联系我: fareedhassankhan12@gmail.com

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消