时间:2021-04-14 09:03:32 | 栏目:Python代码 | 点击:次
代码:
import torch class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size, 1).random_() % class_num print(label.size()) one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1) print(one_hot)
输出:
torch.Size([4, 1]) tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
注意:
label的形状必须是[n,1]的,也就是必须是二维的,且第二个维度长度为1,如果是一维度的,则需要升维度,代码如下:
import torch class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size).random_() % class_num print(label.size()) label = torch.unsqueeze(label,dim=1) print(label.size())