TensorFlow Dataset API leads to graph explosion

I have a set of image data for training.

I am using the dataset API like so:

self._dataset = tf.contrib.data.Dataset.from_tensor_slices((self._images_list, self._labels_list))
self._dataset = self._dataset.map(self.load_image)
self._dataset = self._dataset.batch(batch_size)
self._dataset = self._dataset.shuffle(buffer_size=shuffle_buffer_size)
self._dataset = self._dataset.repeat()
self._iterator = self._dataset.make_one_shot_iterator()

If I use a small amount of data for training, everything is fine. If I use all of my data, TensorFlow crashes with this error: ValueError: GraphDef cannot exceed 2GB.

It seems like TensorFlow is trying to load all the data into the graph instead of only the data it needs ... not sure ...
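
To illustrate what I suspect is happening (the array and its shape below are made up, purely to show the effect):

import numpy as np
import tensorflow as tf

# from_tensor_slices on an in-memory array seems to turn it into a tf.constant,
# so the serialized GraphDef grows with the data and eventually hits the 2GB limit
data = np.zeros((100, 512, 512), dtype=np.float32)  # ~100 MB of dummy data
dataset = tf.contrib.data.Dataset.from_tensor_slices(data)
print(tf.get_default_graph().as_graph_def().ByteSize())  # on the order of the array size, not a few KB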

Any advice would be great!

Update ... found solution / workaround

According to this post: TensorFlow Dataset API doubles the size of the protobuf graph file

I replaced make_one_shot_iterator() with make_initializable_iterator() and, of course, ran the iterator initializer after the session was created:

init = tf.global_variables_initializer()
sess.run(init)
sess.run(train_data._iterator.initializer)
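
For completeness, here is roughly what the whole pipeline looks like with the initializable iterator. The placeholder feeding shown below is the pattern from the linked post and is what actually keeps the data out of the GraphDef; all the concrete names and values (images_list, labels_list, load_image, batch_size, shuffle_buffer_size) are stand-ins for my real ones:

import tensorflow as tf

# Stand-ins for the real data and parameters
images_list = ["img_0.jpg", "img_1.jpg"]   # file paths in the real code
labels_list = [0, 1]
batch_size = 2
shuffle_buffer_size = 100

def load_image(path, label):
    # stand-in for the real self.load_image; the real one would read and decode
    # the file, e.g. tf.image.decode_jpeg(tf.read_file(path))
    image = tf.zeros([64, 64, 3], dtype=tf.float32)
    return image, label

# The lists are fed through placeholders when the iterator is initialized,
# so they are never embedded as constants in the GraphDef
images_ph = tf.placeholder(tf.string, shape=[None])
labels_ph = tf.placeholder(tf.int64, shape=[None])

dataset = tf.contrib.data.Dataset.from_tensor_slices((images_ph, labels_ph))
dataset = dataset.map(load_image)
dataset = dataset.batch(batch_size)
dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
dataset = dataset.repeat()

iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer,
             feed_dict={images_ph: images_list, labels_ph: labels_list})
    images, labels = sess.run(next_batch)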

      

But I am leaving the question open as it seems to me that this is a workaround, not a solution ...
