I'm trying to write my own data generator for an image segmentation task so that I have more control over the transformations I can apply. To start, I'm writing a toy generator that yields random crops of an image using `tf.image.random_crop`:
```python
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras.preprocessing.image as k_img


class DataGen:
    seed = 1

    def __init__(self, seed=1):
        self.seed = seed

    def flow(self):
        # tf.random.set_seed(self.seed)
        img = np.asarray(k_img.load_img('data/train/images/all/image.png', color_mode='grayscale'))
        while True:
            # Take a random 32x32 crop, passing the op-level seed
            img_crop = tf.image.random_crop(img, (32, 32), seed=self.seed)
            yield img_crop
```
The use case looks like this:

```python
gen = DataGen(1234).flow()
gen2 = DataGen(1234).flow()
plt.imshow(next(gen))   # image 1
plt.imshow(next(gen))   # image 2
plt.imshow(next(gen2))  # image 3
plt.imshow(next(gen2))  # image 4
```
The desired behavior is for images 1 and 3 to be the same and for images 2 and 4 to be the same. However, passing a seed to `tf.image.random_crop` in this context does not achieve this: nothing is repeated. Calling `tf.random.set_seed` (commented out above) does produce a repeatable sequence, but it sets the global seed, which affects every random op in the program rather than just this generator.
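To make the desired behavior concrete, here is a minimal sketch of it in plain NumPy (a swapped-in technique, not TensorFlow): each instance owns its own `np.random.Generator`, so two instances created with the same seed yield the same crop sequence without touching any global state. Passing the image array into the constructor, the `size` parameter, and the synthetic stand-in image are assumptions for illustration.

```python
import numpy as np


class DataGen:
    """Toy crop generator whose randomness is local to each instance."""

    def __init__(self, img, seed=1, size=(32, 32)):
        self.img = img
        self.size = size
        # Each instance owns its own RNG; no global seed is set or read.
        self.rng = np.random.default_rng(seed)

    def flow(self):
        h, w = self.img.shape[:2]
        ch, cw = self.size
        while True:
            # Draw the top-left corner of the crop from this instance's RNG
            # (upper bound is exclusive, so the crop always fits).
            top = self.rng.integers(0, h - ch + 1)
            left = self.rng.integers(0, w - cw + 1)
            yield self.img[top:top + ch, left:left + cw]


# Synthetic stand-in for the grayscale image loaded with load_img (assumption).
img = np.arange(128 * 128).reshape(128, 128)
gen = DataGen(img, seed=1234).flow()
gen2 = DataGen(img, seed=1234).flow()
first, second = next(gen), next(gen)
third, fourth = next(gen2), next(gen2)
print(np.array_equal(first, third), np.array_equal(second, fourth))  # True True
```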
How do I achieve the desired behavior?
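A possible approach, sketched under the assumption of TensorFlow 2.4 or later. Per the `tf.random.set_seed` documentation, in eager mode an op-level seed is combined with a default global seed, and the op's kernel keeps an internal counter that advances on every call; that would explain why repeated `random_crop` calls, even from a second `DataGen` with the same seed, never repeat. Stateless ops avoid this: `tf.image.stateless_random_crop` returns the same crop for the same `(image, size, seed)`, and a per-instance `tf.random.Generator` can supply a fresh seed on every step without setting the global seed. Passing the image into the constructor and the synthetic stand-in image are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf


class DataGen:
    def __init__(self, img, seed=1, size=(32, 32)):
        self.img = tf.convert_to_tensor(img)
        self.size = size
        # Per-instance generator: its state comes only from `seed`,
        # independent of tf.random.set_seed and of other DataGen instances.
        self.rng = tf.random.Generator.from_seed(seed)

    def flow(self):
        while True:
            # make_seeds returns an int64 tensor of shape [2, count];
            # take one column as the shape-[2] seed the stateless op expects.
            seed = self.rng.make_seeds(1)[:, 0]
            # Stateless op: same (image, size, seed) always gives the same crop.
            yield tf.image.stateless_random_crop(self.img, size=self.size, seed=seed)


# Synthetic stand-in for the loaded grayscale image (assumption).
img = np.arange(128 * 128, dtype=np.int32).reshape(128, 128)
gen = DataGen(img, 1234).flow()
gen2 = DataGen(img, 1234).flow()
first, second = next(gen), next(gen)
third, fourth = next(gen2), next(gen2)
print(np.array_equal(first.numpy(), third.numpy()),
      np.array_equal(second.numpy(), fourth.numpy()))  # True True
```

In the real generator, the synthetic array would be replaced by the `np.asarray(k_img.load_img(...))` call from the question.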