我在用pytorch读取mnist数据集时,采用了两种方法:官方下载和读取本地制作好的数据集,现在读出来的图片的torch.szie()大小不同,分别是torch.Size([1, 28, 28])和torch.Size([3, 28, 28]),请问有什么办法可以把3,28,28变成1,28,28,谢谢!
train_dataset1 = datasets.MNIST(
root='./data', train=True, transform=transforms.ToTensor(), download=True)
train\_loader1 = DataLoader\(train\_dataset\, batch\_size=batch\_size\, shuffle=True\)
from torchvision.datasets import ImageFolder
batch_size = 128
path='D:/work/'
train_dataset2 = ImageFolder(path,transform=transforms.ToTensor())
train\_loader2 = DataLoader\(train\_dataset\, batch\_size=batch\_size\, shuffle=False\)\
print(train_dataset1[0][0].size())
print(train_dataset2[0][0].size())
out:
torch.Size([1, 28, 28])
torch.Size([3, 28, 28])