
I'm coding a Pix2Pix network with my own load_input/real_image function, and I'm currently creating the dataset with tf.data.Dataset. The problem is that my dataset has the wrong shape:

I've tried applying a few tf.data.experimental functions, but none of them do what I want.

    raw_data = [load_image_train(category)
                for category in SELECTED_CATEGORIES
                for _ in range(min(MAX_SAMPLES_PER_CATEGORY, category[1]))]
    train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
    train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE)
    train_dataset = train_dataset.batch(1)

I have: <BatchDataset shapes: (None, 2, 256, 256, 3), types: tf.float32>

I want: <DatasetV1Adapter shapes: ((None, 256, 256, 3), (None, 256, 256, 3)), types: (tf.float32, tf.float32)>
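For context, the unwanted shape can be reproduced with dummy data (np.zeros stands in for my loaded image pairs; shapes as above):

```python
import numpy as np
import tensorflow as tf

# Each sample stacks the (input, real) pair along a new axis,
# so raw_data ends up with shape (N, 2, 256, 256, 3).
raw_data = np.zeros((4, 2, 256, 256, 3), dtype=np.float32)

# Slicing along the first axis keeps the pair axis inside each element,
# which is where the (None, 2, 256, 256, 3) batch shape comes from.
dataset = tf.data.Dataset.from_tensor_slices(raw_data).batch(1)
```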


1 Answer


You can do it in two ways.

Option 1 (Preferred)

# Split the pair axis once up front, then slice the two tensors in parallel.
raw_data1, raw_data2 = tf.unstack(raw_data, axis=1)
train_dataset = tf.data.Dataset.from_tensor_slices((raw_data1, raw_data2))

Option 2

def map_fn(data):
    # data has shape (2, 256, 256, 3); split it into an (input, target) pair.
    input_image, real_image = tf.unstack(data, axis=0)
    return input_image, real_image

train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
train_dataset = train_dataset.map(map_fn)
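With dummy data of the question's shape, the mapped dataset yields the same paired structure (batching added here for parity with the question's pipeline):

```python
import numpy as np
import tensorflow as tf

# Dummy stand-in for raw_data: (N, 2, 256, 256, 3).
raw_data = np.zeros((4, 2, 256, 256, 3), dtype=np.float32)

def map_fn(data):
    # One (2, 256, 256, 3) element; unstack into an (input, target) pair.
    input_image, real_image = tf.unstack(data, axis=0)
    return input_image, real_image

train_dataset = tf.data.Dataset.from_tensor_slices(raw_data).map(map_fn).batch(1)
```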