Replacing tf.placeholder and feed_dict with tf.data API

I have an existing TensorFlow model that uses a tf.placeholder for the model input and the feed_dict parameter of tf.Session().run() to feed in data. Previously the entire dataset was read into memory and passed in this way.
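For concreteness, the original setup looked roughly like this (the layer sizes, loss, and random data below are illustrative, not the real model):

import tensorflow as tf
import numpy as np

# Illustrative sketch of the original feed_dict-based setup; the layer
# sizes, loss, and random data are placeholders, not the real model.
input_size = 10
x_in = tf.placeholder(dtype=tf.float32, shape=(None, input_size))
layer = tf.layers.dense(x_in, 3, activation=tf.nn.relu)
loss = tf.reduce_sum(tf.square(layer))

data = np.random.standard_normal([4, input_size]).astype(np.float32)  # entire dataset in memory

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss, feed_dict={x_in: data}))  # feed the in-memory array directly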

I want to use a much larger dataset and take advantage of the performance improvements of the tf.data API. I've defined a tf.data.TextLineDataset and a one-shot iterator from it, but I'm having a hard time figuring out how to get the data into the model to train it.
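The dataset definition looks roughly like this (the file name is a placeholder and I've left out the per-line parsing):

import tensorflow as tf

# Rough sketch of the new input pipeline; "train.txt" is a placeholder
# and any per-line decoding is omitted.
dataset = tf.data.TextLineDataset("train.txt")
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()  # a tf.Tensor that already lives in the graph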

At first I tried to just define the feed_dict as a dictionary mapping the placeholder to iterator.get_next(), but that gave me an error saying the value of a feed cannot be a tf.Tensor object. More digging led me to understand that this is because the object returned by iterator.get_next() is already part of the graph, unlike what you would feed into feed_dict, and that I shouldn't be trying to use feed_dict at all anyway for performance reasons.
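In other words, this attempt (again with illustrative shapes and names) fails:

import tensorflow as tf
import numpy as np

# Minimal sketch of the failing attempt; shapes and names are illustrative.
x_in = tf.placeholder(tf.float32, shape=(None, 10))
loss = tf.reduce_sum(tf.square(x_in))

dataset = tf.data.Dataset.from_tensor_slices(
    np.random.standard_normal([4, 10]).astype(np.float32)).batch(2)
iterator = dataset.make_one_shot_iterator()

with tf.Session() as sess:
    # Fails because iterator.get_next() is a tf.Tensor, not a NumPy array
    # or Python value: "The value of a feed cannot be a tf.Tensor object."
    sess.run(loss, feed_dict={x_in: iterator.get_next()})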

So now I've gotten rid of the input tf.placeholder and replaced it with a parameter to the constructor of the class that defines my model; when constructing the model in my training code, I pass the output of iterator.get_next() to that parameter. This already seems a bit clunky because it breaks separation between the definition of the model and the datasets/training procedure. And I'm now getting an error saying that the Tensor representing (I believe) my model's input must be from the same graph as the Tensor from iterator.get_next().

Am I on the right track with this approach and just doing something wrong with how I set up the graph and the session, or something like that? (The datasets and model are both initialized outside of a session, and the error occurs before I attempt to create one.)

Or am I totally off base with this and need to do something different like use the Estimator API and define everything in an input function?

Here is some code demonstrating a minimal example:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])
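For completeness, once the graph mismatch is resolved (i.e. tf.reset_default_graph() is removed or moved before the dataset is created), I expect the training loop to look roughly like this; the optimizer and learning rate are just placeholders:

# Sketch of the intended training loop, assuming the model and the
# iterator end up in the same (default) graph.
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(model.loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        while True:
            _, loss_value = sess.run([train_op, model.loss])
            print(loss_value)
    except tf.errors.OutOfRangeError:
        pass  # the one-shot iterator is exhausted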
