pytorch 在sequential中使用view来reshape的例子
时间:2021-01-18 14:44:59|栏目:Python代码|点击: 次
pytorch中view是tensor方法,然而在sequential中包装的是nn.module的子类,
因此需要自己定义一个方法:
import torch.nn as nn class Reshape(nn.Module): def __init__(self, *args): super(Reshape, self).__init__() self.shape = args def forward(self, x): # 如果数据集最后一个batch样本数量小于定义的batch_batch大小,会出现mismatch问题。可以自己修改下,如只传入后面的shape,然后通过x.szie(0),来输入。 return x.view(self.shape)
class Reshape(nn.Module): def __init__(self, *args): super(Reshape, self).__init__() self.shape = args def forward(self, x): return x.view((x.size(0),)+self.shape)
上一篇:python selenium操作cookie的实现
栏 目:Python代码
下一篇:Python中关于Sequence切片的下标问题详解
本文标题:pytorch 在sequential中使用view来reshape的例子
本文地址:http://www.codeinn.net/misctech/47058.html






