TensorFlow: how to restore the pretrained Inception v3 weights after extending the graph with a new final layer

I have this network model: the pretrained Inception v3 from Open Images. https://storage.googleapis.com/openimages/2016_08/model_2016_08.tar.gz

I want to extend it with a new final layer:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.slim.python.slim.nets import inception
from tensorflow.python.ops import array_ops
from tensorflow.python.training import saver as tf_saver

slim = tf.contrib.slim
FLAGS = None


def PreprocessImage(image, central_fraction=0.875):
    """Load and preprocess an image.

    Args:
      image: a tf.string tensor with an JPEG-encoded image.
      central_fraction: do a central crop with the specified
        fraction of image covered.
    Returns:
      An ops.Tensor that produces the preprocessed image.
    """

    # Decode Jpeg data and convert to float.
    image = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)

    image = tf.image.central_crop(image, central_fraction=central_fraction)
    # Make into a 4D tensor by setting a 'batch size' of 1.
    image = tf.expand_dims(image, [0])
    image = tf.image.resize_bilinear(image,
                                     [FLAGS.image_size, FLAGS.image_size],
                                     align_corners=False)

    # Center the image about 128.0 (which is done during training) and normalize.
    image = tf.multiply(image, 1.0 / 127.5)
    return tf.subtract(image, 1.0)


def main(args):
    if not os.path.exists(FLAGS.checkpoint):
        tf.logging.fatal(
            'Checkpoint %s does not exist. Have you downloaded it? See tools/download_data.sh',
            FLAGS.checkpoint)
    g = tf.Graph()
    with g.as_default():
        input_image = tf.placeholder(tf.string)
        processed_image = PreprocessImage(input_image)

        with slim.arg_scope(inception.inception_v3_arg_scope()):
            logits, end_points = inception.inception_v3(
                processed_image, num_classes=FLAGS.num_classes, is_training=False)

        predictions = end_points['multi_predictions'] = tf.nn.sigmoid(
            logits, name='multi_predictions')

        sess = tf.Session()

        saver = tf_saver.Saver()

        saver.restore(sess, FLAGS.checkpoint)

        logits_2 = layers.conv2d(
            end_points['PreLogits'],
            FLAGS.num_classes, [1, 1],
            activation_fn=None,
            normalizer_fn=None,
            scope='Conv2d_final_1x1')

        logits_2 = array_ops.squeeze(logits_2, [1, 2], name='SpatialSqueeze_2')

        predictions_2 = end_points['multi_predictions_2'] = tf.nn.sigmoid(logits_2, name='multi_predictions_2')

        sess.run(tf.global_variables_initializer())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='2016_08/model.ckpt',
                        help='Checkpoint to run inference on.')
    parser.add_argument('--dict', type=str, default='dict.csv',
                        help='Path to a dict.csv that translates from mid to a display name.')
    parser.add_argument('--image_size', type=int, default=299,
                        help='Image size to run inference on.')
    parser.add_argument('--num_classes', type=int, default=6012,
                        help='Number of output classes.')
    parser.add_argument('--image_path', default='test_set/0a9ed4def08fe6d1')
    FLAGS = parser.parse_args()
    tf.app.run()

      

How can I restore only the pretrained Inception v3 weights saved in the model.ckpt file, and how can I initialize the new conv2d layer?

1 answer


I solved the problem!

You need to call saver.restore(sess, FLAGS.checkpoint) after initializing the network with sess.run(tf.global_variables_initializer()).

Important: saver = tf_saver.Saver() must be created before the new layers are added to the graph. That way the Saver only knows about the part of the graph that existed before the new layers were created, so saver.restore(sess, FLAGS.checkpoint) restores only the pretrained Inception v3 weights and leaves the freshly initialized new layer untouched.
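
To see why the construction order matters, here is a small self-contained toy example of the same Saver behavior (the variable names are made up for illustration and are not part of the Inception model); the full corrected script follows below.

import os
import tempfile

import tensorflow as tf

ckpt_path = os.path.join(tempfile.mkdtemp(), 'toy.ckpt')

# 1) Save a checkpoint that contains only a "pretrained" variable.
with tf.Graph().as_default():
    pretrained = tf.get_variable('pretrained_var', initializer=tf.constant([1.0, 2.0]))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

# 2) Rebuild the graph and create the Saver BEFORE adding a new variable.
with tf.Graph().as_default():
    pretrained = tf.get_variable('pretrained_var', initializer=tf.constant([0.0, 0.0]))
    saver = tf.train.Saver()  # captures only 'pretrained_var'
    new_var = tf.get_variable('new_var', initializer=tf.constant([5.0, 5.0]))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # initializes both variables
        saver.restore(sess, ckpt_path)               # overwrites only 'pretrained_var'
        print(sess.run(pretrained))  # [1. 2.]  restored from the checkpoint
        print(sess.run(new_var))     # [5. 5.]  keeps its fresh initialization

The full corrected script: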

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.slim.python.slim.nets import inception
from tensorflow.python.ops import array_ops
from tensorflow.python.training import saver as tf_saver

slim = tf.contrib.slim
FLAGS = None


def PreprocessImage(image, central_fraction=0.875):
    """Load and preprocess an image.

    Args:
      image: a tf.string tensor with an JPEG-encoded image.
      central_fraction: do a central crop with the specified
        fraction of image covered.
    Returns:
      An ops.Tensor that produces the preprocessed image.
    """

    # Decode Jpeg data and convert to float.
    image = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)

    image = tf.image.central_crop(image, central_fraction=central_fraction)
    # Make into a 4D tensor by setting a 'batch size' of 1.
    image = tf.expand_dims(image, [0])
    image = tf.image.resize_bilinear(image,
                                     [FLAGS.image_size, FLAGS.image_size],
                                     align_corners=False)

    # Center the image about 128.0 (which is done during training) and normalize.
    image = tf.multiply(image, 1.0 / 127.5)
    return tf.subtract(image, 1.0)


def main(args):
    if not os.path.exists(FLAGS.checkpoint):
        tf.logging.fatal(
            'Checkpoint %s does not exist. Have you downloaded it? See tools/download_data.sh',
            FLAGS.checkpoint)
    g = tf.Graph()
    with g.as_default():
        input_image = tf.placeholder(tf.string)
        processed_image = PreprocessImage(input_image)

        with slim.arg_scope(inception.inception_v3_arg_scope()):
            logits, end_points = inception.inception_v3(
                processed_image, num_classes=FLAGS.num_classes, is_training=False)

        predictions = end_points['multi_predictions'] = tf.nn.sigmoid(
            logits, name='multi_predictions')

        sess = tf.Session()

        # Create the Saver BEFORE adding the new layer: it will only know about
        # the pretrained Inception v3 variables that exist at this point.
        saver = tf_saver.Saver()

        # New final 1x1 convolution on top of the 'PreLogits' endpoint.
        logits_2 = layers.conv2d(
            end_points['PreLogits'],
            FLAGS.num_classes, [1, 1],
            activation_fn=None,
            normalizer_fn=None,
            scope='Conv2d_final_1x1')

        logits_2 = array_ops.squeeze(logits_2, [1, 2], name='SpatialSqueeze_2')

        predictions_2 = end_points['multi_predictions_2'] = tf.nn.sigmoid(
            logits_2, name='multi_predictions_2')

        # Initialize ALL variables first (including the new layer), ...
        sess.run(tf.global_variables_initializer())

        # ... then restore the checkpoint: the Saver overwrites only the
        # pretrained variables and leaves the new layer's weights untouched.
        saver.restore(sess, FLAGS.checkpoint)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='2016_08/model.ckpt',
                        help='Checkpoint to run inference on.')
    parser.add_argument('--dict', type=str, default='dict.csv',
                        help='Path to a dict.csv that translates from mid to a display name.')
    parser.add_argument('--image_size', type=int, default=299,
                        help='Image size to run inference on.')
    parser.add_argument('--num_classes', type=int, default=6012,
                        help='Number of output classes.')
    parser.add_argument('--image_path', default='test_set/0a9ed4def08fe6d1')
    FLAGS = parser.parse_args()
    tf.app.run()
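
An alternative that does not depend on the order in which the graph is built is to tell the Saver explicitly which variables to restore and to initialize only the new layer. This is a minimal sketch, not what the code above does: it reuses slim, tf_saver, sess and FLAGS from the script above and assumes the new layer's variables live under the Conv2d_final_1x1 scope.

# Restore everything except the variables created under the new layer's scope.
variables_to_restore = slim.get_variables_to_restore(exclude=['Conv2d_final_1x1'])
saver = tf_saver.Saver(var_list=variables_to_restore)
saver.restore(sess, FLAGS.checkpoint)

# Initialize only the variables of the newly added layer.
new_variables = slim.get_variables(scope='Conv2d_final_1x1')
sess.run(tf.variables_initializer(new_variables))

With an explicit var_list the restore no longer relies on when the Saver was constructed, which is easier to keep correct if more layers are added later.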

      
