2

我想在将图像传递给数据加载器之前对其进行二值化,我创建了一个运行良好的数据集类。但在__getitem__()我想对图像进行阈值处理的方法中:

    def __getitem__(self, idx):
        # Open image, apply transforms and return with label
        img_path = os.path.join(self.dir, self.filelist[filename"])
        image = Image.open(img_path)
        label = self.x_data.iloc[idx]["label"]

        # Applying transformation to the image
        if self.transforms is not None:
           image = self.transforms(image)

        # applying threshold here:
        my_threshold = 240
        image = image.point(lambda p: p < my_threshold and 255)
        image = torch.tensor(image)

        return image, label

然后我尝试调用数据集:

    data_transformer = transforms.Compose([
        transforms.Resize((10, 10)),
        transforms.Grayscale()
        //transforms.ToTensor()
    ])

train_set = MyNewDataset(data_path, data_transformer, rows_train)

由于我已经在 PIL 对象上应用了阈值,因此我需要在之后应用转换为张量对象,但由于某种原因它崩溃了。有人可以帮助我吗?

4

1 回答 1

3

为什么在从PIL.Imageto转换后不应用二值化torch.Tensor

class ThresholdTransform(object):
  def __init__(self, thr_255):
    self.thr = thr_255 / 255.  # input threshold for [0..255] gray level, convert to [0..1]

  def __call__(self, x):
    return (x > self.thr).to(x.dtype)  # do not change the data type

一旦你有了这个转换,你只需添加它:

data_transformer = transforms.Compose([
        transforms.Resize((10, 10)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        ThresholdTransform(thr_255=240)
    ])
于 2021-01-31T12:51:42.927 回答