1

我使用tf.data API 创建了管道,用于读取图像数据集。我有一个高分辨率的大数据集。但是,每次尝试读取所有数据集时,计算机都会因为代码使用所有 RAM 而崩溃。我用大约 1280 张图像测试了代码,它没有任何错误。但是当我使用所有数据集时,模型就会崩溃。所以,我想知道是否有办法让tf.data读取前面的一个或两个批次不超过这个。

这是我用来创建管道的代码:

    def decode_img(self, img):
        img = tf.image.convert_image_dtype(img, tf.float32, saturate=False)
        img = tf.image.resize(img, size=self.input_dim, antialias=False, name=None)
        return img

    def get_label(self, label):
        y = np.zeros(self.n_class, dtype=np.float32)
        y[label] = 1
        return y

    def process_path(self, file_path, label):
        label = self.get_label(label)
        img = Image.open(file_path)

        width, height = img.size
        # Setting the points for cropped image
        new_hight = height // 2
        new_width = width // 2
        newsize = (new_width, new_hight)
        img = img.resize(newsize)

        if self.aug_img:
            img = self.policy(img)
        img = self.decode_img(np.array(img, dtype=np.float32))
        return img, label

    def create_pip_line(self):

        def _fixup_shape(images, labels):
            images.set_shape([None, None, 3])
            labels.set_shape([7])  # I have 19 classes
            return images, labels

        tf_ds = tf.data.Dataset.from_tensor_slices((self.df["file_path"].values, self.df["class_num"].values))
        tf_ds = tf_ds.map(lambda img, label: tf.numpy_function(self.process_path,
                                                               [img, label],
                                                               (tf.float32, tf.float32)),
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
        tf_ds = tf_ds.map(_fixup_shape)

        if not self.is_val:
            tf_ds = tf_ds.shuffle(len(self.df), reshuffle_each_iteration=True)
        tf_ds = tf_ds.batch(self.batch_size).repeat(self.epoch_num)
        self.tf_ds = tf_ds.prefetch(tf.data.experimental.AUTOTUNE)
4

1 回答 1

2

我的代码中的主要问题是 Shuffle 功能。这个函数有两个参数,第一个是要洗牌的数据数量,第二个是每个时期的重复数据。但是,我发现将加载到内存中的数据数量取决于此函数。因此,我将所有数据的数量减少到 100,这使得管道加载 100 个图像并将它们打乱,然后再加载 100 个,依此类推。

if not self.is_val:
            tf_ds = tf_ds.shuffle(100, reshuffle_each_iteration=True)
于 2020-12-02T14:57:34.437 回答