使用Python实现深度学习模型:生成对抗网络(GAN)

生成对抗网络(Generative Adversarial Network,GAN)是一种无监督学习的深度学习模型,由Ian Goodfellow等人在2014年提出。GAN包含两个相互竞争的神经网络:生成器(Generator)和判别器(Discriminator)。生成器试图生成看起来像真实数据的假数据,而判别器则试图区分真实数据和生成数据。通过这种对抗过程,生成器能够生成非常逼真的数据。本教程将详细介绍如何使用Python和PyTorch库实现一个简单的GAN,并展示其在MNIST数据集上的应用。

什么是生成对抗网络(GAN)?

生成对抗网络由两个部分组成:

  • 生成器(Generator):接受随机噪声作为输入,并生成假数据。
  • 判别器(Discriminator):接受数据(真实或生成)作为输入,并预测该数据是真实的还是生成的。

GAN的训练过程是生成器和判别器之间的一个博弈:生成器试图欺骗判别器,而判别器试图提高识别真实数据和假数据的能力。

实现步骤

步骤 1:导入所需库

首先,我们需要导入所需的Python库:PyTorch用于构建和训练GAN模型,Matplotlib用于数据的可视化。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

步骤 2:准备数据

我们将使用MNIST数据集作为示例数据。MNIST是一个手写数字数据集,常用于图像处理的基准测试。

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))  # 将图像归一化到[-1, 1]范围内
])

# 下载并加载训练数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

步骤 3:定义生成器和判别器模型

我们定义一个简单的生成器和判别器模型。

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)


class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

# 定义模型参数
input_size = 100  # 噪声向量的维度
hidden_size = 256
image_size = 28 * 28  # MNIST图像的维度

# 创建生成器和判别器实例
G = Generator(input_size, hidden_size, image_size)
D = Discriminator(image_size, hidden_size, 1)

步骤 4:定义损失函数和优化器

我们选择二元交叉熵(Binary Cross Entropy,BCE)损失函数作为模型训练的损失函数,并使用Adam优化器进行优化。

criterion = nn.BCELoss()
lr = 0.0002

# 创建生成器和判别器的优化器
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)

步骤 5:训练模型

我们使用定义的生成器和判别器模型对MNIST数据集进行训练。

num_epochs = 50

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)
        images = images.view(batch_size, -1)

        # 创建标签
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # 训练判别器
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(batch_size, input_size)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        z = torch.randn(batch_size, input_size)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
          f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')

步骤 6:可视化生成结果

训练完成后,我们可以使用训练好的生成器模型生成一些新的手写数字图像,并进行可视化。

# 生成一些新图像
z = torch.randn(64, input_size)
fake_images = G(z)
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)

# 可视化生成的图像
grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True)
plt.imshow(grid.permute(1, 2, 0).detach().numpy())
plt.title('Generated Images')
plt.show()

总结

通过本教程,你学会了如何使用Python和PyTorch库实现一个简单的生成对抗网络(GAN),并在MNIST数据集上进行训练和生成图像。生成对抗网络是一种强大的生成模型,能够生成逼真的图像数据,广泛应用于图像生成、数据增强、风格转换等领域。希望本教程能够帮助你理解GAN的基本原理和实现方法,并启发你在实际应用中使用GAN解决生成任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/631931.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AC/DC电源模块在工业自动化领域的应用探析

BOSHIDA AC/DC电源模块在工业自动化领域的应用探析 AC/DC电源模块是一种将交流电转换为直流电的电力转换设备,在工业自动化领域具有广泛的应用。本文将从稳定性、效率和可靠性三个方面对AC/DC电源模块在工业自动化领域的应用进行探析。 首先,AC/DC电源模…

【SQL】SQL常见面试题总结(1)

目录 1、检索数据1.1、从 Customers 表中检索所有的 ID1.2、检索并列出已订购产品的清单1.2、检索所有列 2、排序检索数据2.1、检索顾客名称并且排序2.2、对顾客 ID 和日期排序2.3、按照数量和价格排序2.4、检查 SQL 语句 3、过滤数据3.1、返回固定价格的产品3.2、返回产品并且…

React 第三十七章 Scheduler 最小堆算法

在 Scheduler 中&#xff0c;使用最小堆的数据结构在对任务进行排序。 // 两个任务队列 var taskQueue: Array<Task> []; var timerQueue: Array<Task> [];push(timerQueue, newTask); // 像数组中推入一个任务 pop(timerQueue); // 从数组中弹出一个任务 time…

【漏洞复现】用友 NC portal-registerServlet JNDI注入漏洞

0x01 产品简介 用友NC是用友网络科技股份有限公司开发的一款大型企业数字化平台。它主要用于企业的财务核算、成本管理、资金管理、固定资产管理、应收应付管理等方面的工作,致力于帮助企业建立科学的财务管理体系,提高财务核算的准确性和效率。 0x02 漏洞概述 用友NC存在…

Elasticsearch 在滴滴的应用与实践

滴滴 Elasticsearch 简介 简介 Elasticsearch 是一个基于 Lucene 构建的开源、分布式、RESTful 接口的全文搜索引擎&#xff0c;其每个字段均可被索引&#xff0c;且能够横向扩展至数以百计的服务器存储以及处理 TB 级的数据&#xff0c;其可以在极短的时间内存储、搜索和分析大…

登录接口取到token,加到请求头中,通过服务器验证#Vue3

登录接口取到token&#xff0c;加到请求头中&#xff0c;通过服务器验证#Vue3 Token验证的基本流程 1.服务端收到请求&#xff0c;去验证用户名与密码 2.验证成功后&#xff0c;服务端会签发一个 Token&#xff0c;再把这个 Token 发送给客户端 3.客户端收到 Token 以后可以把它…

Linux文件系统详解

&#x1f30e;Linux文件系统 文章目录&#xff1a; Linux文件系统 简单认识磁盘 文件系统       磁盘线性结构抽象       文件系统存储方法 inode Table         inode Bitmap         Data Block         Block Bitmap         …

【漏洞复现】方正全媒体采编系统密码泄露漏洞

0x01 产品简介 方正全媒体新闻采编系统是一个面向媒体深度融合的技术平台&#xff0c;它以大数据和AI技术为支撑&#xff0c;集成了指挥中心、采集中心、编辑中心、发布中心、绩效考核中心、资料中心等多个功能&#xff0c;全面承载“策采编审发存传评”的融媒体业务流程。 0…

爱吃香蕉的珂珂

题目链接 爱吃香蕉的珂珂 题目描述 注意点 piles.length < h < 10^9如果某堆香蕉少于k根&#xff0c;将吃掉这堆的所有香蕉&#xff0c;然后这一小时内不会再吃更多的香蕉返回可以在 h 小时内吃掉所有香蕉的最小速度 k&#xff08;k 为整数&#xff09; 解答思路 二…

Find My资讯|苹果 iOS 17.5 率先执行跨平台反跟踪器标准

苹果和谷歌公司于 2023 年 5 月宣布推出“检测预期外位置追踪器”&#xff08;Detecting Unwanted Location Trackers&#xff09;行业标准&#xff0c;经过 1 年多的打磨之后&#xff0c;该标准目前已通过 iOS 17.5 部署到 iPhone 上。谷歌也将为运行 Android 6.0 或更高版本的…

【从零开始学架构 架构基础】二 架构设计的复杂度来源:高性能复杂度来源

架构设计的复杂度来源其实就是架构设计要解决的问题&#xff0c;主要有如下几个&#xff1a;高性能、高可用、可扩展、低成本、安全、规模。复杂度的关键&#xff0c;就是新旧技术之间不是完全的替代关系&#xff0c;有交叉&#xff0c;有各自的特点&#xff0c;所以才需要具体…

FestDfs快速安装和数据迁移同步。Ubuntu环境

一&#xff1a;防火墙 ufw status 二&#xff1a;下载 分别是&#xff08;环境依赖&#xff0c;网络模块依赖&#xff0c;安装包&#xff09; git clone https://github.com/happyfish100/libfastcommon.git git clone https://github.com/happyfish100/libserverframe.git …

package-lock.json导致npm install安装nyc出现超时错误

一、背景 前端项目在npm install安装依赖&#xff0c;无法下载组件nyc&#xff0c;详细报错信息&#xff1a; npm ERR! code CERT_HAS_EXPIRED npm ERR! errno CERT_HAS_EXPIRED npm ERR! request to https://registry.npm.taobao.org/nyc/download/nyc-13.3.0.tgz?cache0&a…

析构函数详解

目录 析构函数概念特性对象的销毁顺序 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 &#x1f412;&#x1f412;&#x1f412; 个人主页 &#x1f978;&#x1f978;&#x1f978; C语言 &#x1f43f;️&#x1f43f;️&#x1f43f;️ C语言例题 &…

开源标注工具LabelMe的使用

开源标注工具LabelMe使用Python实现&#xff0c;并使用Qt作为其图形界面&#xff0c;进行图像多边形标注。源码地址:https://github.com/labelmeai/labelme &#xff0c;最新发布版本为v5.4.1&#xff0c;它遵循GNU通用公共许可证的条款。 1.Features (1).多边形、矩形、圆形、…

Linux下mysql备份

参考文章&#xff1a; Linux实现MySQL数据库数据自动备份&#xff0c;并定期删除以前备份文件-CSDN博客文章浏览阅读7.2k次&#xff0c;点赞7次&#xff0c;收藏29次。引言在学习过程中遇到了一个问题&#xff0c;见图&#xff1a;当我进入服务器的数据库时&#xff0c;原来的…

羊大师:羊奶健康的成长伴侣

羊大师&#xff1a;羊奶健康的成长伴侣 在追求健康生活的当下&#xff0c;越来越多的人开始关注饮食的营养与健康。羊大师发现在众多天然食品中&#xff0c;羊奶以其独特的营养价值和健康益处&#xff0c;逐渐成为了人们的新宠。特别是对于正在成长发育的孩子们来说&#xff0…

客户端Web资源缓存

为了提高Web服务器的性能,其中的一种可以提高Web服务器性能的方法就是采用缓存技术。 1.缓存 1.1.什么是缓存&#xff1f; 如果某个资源的计算耗时或耗资源&#xff0c;则执行一次并存储结果。当有人随后请求该资源时&#xff0c;返回存储的结果&#xff0c;而不是再次计算。…

免费视频格式在线转换网站,推荐这5款!

在数字化时代&#xff0c;视频已成为我们日常生活和工作中不可或缺的一部分。然而&#xff0c;随着各种设备和平台的不断涌现&#xff0c;视频格式繁多&#xff0c;常常会出现不兼容的情况。为了解决这一问题&#xff0c;视频格式在线转换网站应运而生&#xff0c;成为了我们应…

【数据结构】排序(归并排序,计数排序)

一、归并排序 基本思想&#xff1a; 归并排序&#xff08;MERGE-SORT&#xff09;是建立在归并操作上的一种有效的排序算法,该算法是采用分治法&#xff08;Divide and Conquer&#xff09;的一个非常典型的应用。将已有序的子序列合并&#xff0c;得到完全有序的序列&#xf…