我也将您的代码下载到了我的本地机器和数据集。必须进行一些调整才能使其在本地运行。我相信模型efficientnet_v2_imagenet1k_b0 与较新的高效网络模型不同,因为此版本确实需要在0 和1 之间缩放像素级别。我运行了模型,无论是否重新缩放,它只有在像素被重新缩放时才能正常工作。下面是我用来测试模型是否正确预测从互联网下载的图像的代码。它按预期工作。
import cv2
class_dict=train_generator.class_indices
print (class_dict)
rev_dict={}
for key, value in class_dict.items():
rev_dict[value]=key
print (rev_dict)
fpath=r'C:\Temp\rps\1.jpg' # an image downloaded from internet that should be paper class
img=plt.imread(fpath)
print (img.shape)
img=cv2.resize(img, (224,224)) # resize to 224 X 224 to be same size as model was trained on
print (img.shape)
plt.imshow(img)
img=img/255.0 # rescale as was done with training images
img=np.expand_dims(img,axis=0)
print(img.shape)
p=model.predict(img)
print (p)
index=np.argmax(p)
print (index)
klass=rev_dict[index]
prob=p[0][index]* 100
print (f'image is of class {klass}, with probability of {prob:6.2f}')
结果是
{'paper': 0, 'rock': 1, 'scissors': 2}
{0: 'paper', 1: 'rock', 2: 'scissors'}
(300, 300, 3)
(224, 224, 3)
(1, 224, 224, 3)
[[9.9902594e-01 5.5121275e-04 4.2284720e-04]]
0
image is of class paper, with probability of 99.90
你的代码中有这个
uploaded = files.upload()
len_file = len(uploaded.keys())
这没有运行,因为未定义文件,因此无法找到导致您的错误分类问题的原因。请记住在 flow_from_directory 中,如果您不指定颜色模式,则默认为 rgb。因此,即使训练图像是 4 通道 PNG,实际模型也是在 3 通道上训练的。因此,请确保您要预测的图像是 3 个通道。