我写了一个 TF 数据管道,看起来像这样(TF 2.6):
def parse(img):
image = tf.image.decode_png(img, channels=3)
image = tf.reshape(image, IMG_SHAPE)
image = tf.cast(image, TARGET_DTYPE)
return image
def decode_batch(serialized_example, is_test=False):
feature_dict = {
'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
}
if not is_test:
feature_dict["some_text"] = tf.io.FixedLenFeature(shape=[MAX_LEN], dtype=tf.int64, default_value=[0]*MAX_LEN)
else:
feature_dict["image_id"] = tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value='')
features = tf.io.parse_example(tf.reshape(serialized_example, [BATCH_SIZE_OVERALL]), features=feature_dict)
images = tf.map_fn(parse, features['image'], parallel_iterations=4, fn_output_signature=TARGET_DTYPE)
if is_test:
image_ids = features["image_id"]
return images, image_ids
else:
targets = tf.cast(features["some_text"], tf.uint8)
return images, targets
def get_dataset(filenames, is_test):
opts = tf.data.Options()
opts.experimental_deterministic = False
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.with_options(opts)
dataset = dataset.interleave(lambda x:
tf.data.TFRecordDataset(x),
cycle_length=4,
num_parallel_calls=4,
)
dataset = dataset.batch(BATCH_SIZE_OVERALL, num_parallel_calls=4, drop_remainder=True)
if not is_test:
dataset = dataset.repeat()
dataset = dataset.shuffle(BATCH_SIZE_OVERALL*6)
dataset = dataset.map(lambda y: decode_batch(y, is_test), num_parallel_calls=4)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
train_ds = get_dataset(TRAIN_TFREC_PATHS, False)
正如您从代码中看到的那样,我使用了 TF 指南中的大部分技巧来正确构建tf.data
管道。我遇到的问题如下:开始训练时,代码没有使用全部4个核心,而是只使用了1个(有时会使用更多核心,但似乎是由train_dist_ds.get_next()
下面代码中的调用引起的)。此外,GPU 几乎没有被使用。分析器说问题出在预处理中,并且tf_data_bottleneck_analysis
表明问题出在ParallelBatch
(尽管一旦他指出ParallelMap
,这似乎是正确的,但这本身并不能说明太多 - 无论如何核心仍然没有得到充分利用)。带有分析器的训练函数如下所示:
def fit_profile(train_ds, val_ds, stop_after_steps):
tf.profiler.experimental.start('logdir')
stat_logger.current_step = 0
train_dist_ds = iter(train_ds)
while True:
stat_logger.batch_start_time = time.time()
stat_logger.current_step += 1
print(f'current step: {stat_logger.current_step}')
with tf.profiler.experimental.Trace('train', step_num=stat_logger.current_step, _r=1):
image_batch, some_text_batch = train_dist_ds.get_next()
train_step(image_batch, some_text_batch)
if stat_logger.current_step == stop_after_steps:
break
tf.profiler.experimental.stop()
正如你所看到的,我没有接触数据集,我没有将它放入任何策略中,它在train_step
(当然是包裹在 中@tf.function
)。问题:有没有办法以某种方式调试图中的计算以进行tf.data
操作?特别是tf.data
在预处理中对每个 API 函数的调用级别——这样我就可以了解要优化什么。只使用一个核心的原因可能是什么?
到目前为止我已经尝试过:
- 将所有可自动调整的参数设置为
tf.data.AUTOTUNE
- 无效; - 仅对数据集对象进行迭代——在这种情况下使用了所有内核,从中我得出结论,问题出在图形执行级别——并行性并未全局关闭;
- 关闭分析器 - 没有效果;
- 降低通话量
parallel_iterations
-map_fn
无效; - 很多奇怪的设置
num_parallel_calls
- 没有影响到它似乎真的无关紧要。