博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch快速入门教程六(使用LSTM做图片分类)
阅读量:5894 次
发布时间:2019-06-19

本文共 1771 字,大约阅读时间需要 5 分钟。

  hot3.png

对于LSTM,我们要处理的数据是一个序列数据,对于图片而言,我们如何将其转换成序列数据呢?图片的大小是28x28,所以我们可以将其看成长度为28的序列,序列中的每个数据的维度是28,这样我们就可以将其变成一个序列数据了。

model

class Rnn(nn.Module):    def __init__(self, in_dim, hidden_dim, n_layer, n_class):        super(Rnn, self).__init__()        self.n_layer = n_layer        self.hidden_dim = hidden_dim        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer,                            batch_first=True)        self.classifier = nn.Linear(hidden_dim, n_class)    def forward(self, x):        # h0 = Variable(torch.zeros(self.n_layer, x.size(1),                                #   self.hidden_dim)).cuda()        # c0 = Variable(torch.zeros(self.n_layer, x.size(1),                                #   self.hidden_dim)).cuda()        out, _ = self.lstm(x)        out = out[:, -1, :]        out = self.classifier(out)        return outmodel = Rnn(28, 128, 2, 10)  # 图片大小是28x28use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速if use_gpu:    model = model.cuda()# 定义loss和optimizercriterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)

这里我们定义了一个LSTM模型,我们需要传入的参数是输入数据的维数28,LSTM输出的维数128,LSTM网络层数2层以及输出的类数10。

在网络定义里面首先需要定义LSTM,而长度为28的序列传入LSTM之后输出的也是长度为28,而输入的维数是28,输出的维数由我们定义为128,最后我们只取输出的最后一个部分传入分类器求出分类概率。

out = out[:, -1, :]通过这种方式,out中的三个维度分别表示batch_size,序列长度和数据维度,所以中间的序列长度取-1,表示取序列中的最后一个数据,这个数据维度为128,再通过分类器,输出10个结果表示每种结果的概率。

另外上面注释掉的部分就是初始的h_0和c_0,这里可以自己定义,如果不定义,默认传入0,也可以根据自己的要求传入自己定义的h_0和c_0。

开始训练

把训练过程的batch_size设置为100,learning_rate设置为0.01,训练20次,最后得到的结果如下

使用LSTM做图片分类

可以发现对于简单的图像分类RNN也能得到一个较好的结果,虽然CNN更多的用在图像领域而RNN更多的用在自然语言处理中。RNNCNN彼此优缺点可以自行百度。

在这里,我整理发布了Pytorch中文文档,方便大家查询使用,同时也准备了中文论坛,欢迎大家学习交流!

Pytorch中文文档已经发布,完美翻译,更加方便大家浏览:

Pytorch中文网:https://ptorch.com/

Pytorch中文文档:https://ptorch.com/docs/1/

转载于:https://my.oschina.net/earnp/blog/1113895

你可能感兴趣的文章
java中如何实现类似goto的作法
查看>>
海归千千万 为何再无钱学森
查看>>
vue2.0 仿手机新闻站(六)详情页制作
查看>>
FreeRTOS的内存管理
查看>>
JSP----九大内置对象
查看>>
The Z-Index CSS Property: A Comprehensive Look | Smashing Coding
查看>>
Java中HashMap详解
查看>>
Office版本差别引发的语法问题
查看>>
Apache——访问控制
查看>>
web前端(10)—— 浮动,清除默认样式
查看>>
ggplot2 aes函数map到data笔记
查看>>
3450: Tyvj1952 Easy
查看>>
delphi基本语法
查看>>
java中的Static class
查看>>
删除重复节点
查看>>
.net请求Webservice简单实现天气预报功能
查看>>
Loj #3056. 「HNOI2019」多边形
查看>>
【3】数据库的表设计和初始化
查看>>
Django rest framework的基本用法
查看>>
正则表达式匹配非需要匹配的字符串(标题自己都绕晕了)
查看>>