我想使用dataset.ImageFolder创建一个图像数据集。
我当前的图像目录结构如下所示:
1:在火车图像中,我的标签包含 00、01 等子文件夹。在每个文件夹中,图像包含与每个标签对应的两位数字
这是我使用的代码,后面是标签没有的输出。与图像匹配
这里的路径
data_dir = "/home/mhamdan/hamdan/MNIST_muldigits/data/double_mnist"
train_dir = data_dir + '/train' # training_set contains training dataset
val_dir = data_dir + '/val' #contains validation dataset
test_dir = data_dir + '/test' #contains test dataset
在此处加载数据
#Load the dataset with Image Folder
trainset = datasets.ImageFolder(train_dir, transform = transformation)
valset = datasets.ImageFolder(val_dir, transform = transformation)
testset = datasets.ImageFolder(test_dir, transform = transformation)
数据加载器
#define data loaders
batch_size = 32
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,num_workers=2)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=True,num_workers=2)
test_loader = DataLoader(testset, batch_size=batch_size,num_workers=1)
这是随机训练图像的绘图
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[1]))
plt.xticks([])
plt.yticks([])
fig
如您在此处看到的,标签与图像不同 标签与图像不同