mardi 2 novembre 2021

Error in custom augmentation layer in Tensorflow 2 while loop

In Tensorflow 2, I want to mask a spectrogram with 0s at random locations. I want to use this in a custom augmentation layer. I’m doing something wrong because I get a mistake because of the different shapes. All kinds of help come in handy. I haven't been able to solve this problem for 3 days because I'm constantly bumping into obstacles. Thanks for the help!

The code:

def spectogram_mask(spectrogram,freq_times,F,time_times,T):

  def body(i,F,spectrogram):
    s_f,s_t=spectrogram.shape
    f=tf.random.uniform(shape=[], minval=0, maxval=F, dtype=tf.int32)
    f0=tf.random.uniform(shape=[], minval=0, maxval=s_f-f, dtype=tf.int32)
    spectrogram=tf.reshape(tf.concat([spectrogram[:f0,:],tf.zeros((f,s_t)),spectrogram[f0+f:,:]],0),(5511, 101))
    return tf.add(i, 1), F, spectrogram

  def body2(j,T,spectrogram):
    s_f,s_t=spectrogram.shape
    t=tf.random.uniform(shape=[], minval=0, maxval=T, dtype=tf.int32)
    t0=tf.random.uniform(shape=[], minval=0, maxval=s_t-t, dtype=tf.int32)
    spectrogram=tf.reshape(tf.concat([spectrogram[:,:t0],tf.zeros((s_f,t)),spectrogram[:,t0+t:]],1),(5511, 101))
    return tf.add(j, 1),T,spectrogram

  def condition(i,F,spectrogram):
    return tf.less(i, freq_times)

  def condition2(j,T,spectrogram):
    return tf.less(j,time_times)

  i = tf.constant(0)
  i,F,spectrogram = tf.while_loop(condition, body, [i,F,spectrogram], shape_invariants=[i.get_shape(), None, spectrogram.get_shape()], parallel_iterations=1)
  
  j = tf.constant(0)
  j,T,spectrogram = tf.while_loop(condition2, body2, [j,T,spectrogram], shape_invariants=[j.get_shape(), None, spectrogram.get_shape()], parallel_iterations=1)
    
  return tf.reshape(spectrogram,(5511, 101))
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers.experimental.preprocessing import PreprocessingLayer
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import control_flow_util
#from tensorflow.python.ops import control_flow_ops
from tensorflow.random import Generator
import numpy as np
from keras.preprocessing.image import img_to_array, array_to_img

class AugmentLayer(PreprocessingLayer):

  def __init__(self,freq_times,F,time_times,T,seed=None,**kwargs):
    super().__init__(**kwargs)
    self.freq_times=freq_times
    self.F=F
    self.time_times=time_times
    self.T=T
    self.seed=seed
        
  def call(self,inputs,training):
    @tf.function
    def augmentation():
      augmented = inputs
      #self._rng.make_seeds()[:, 0])
      l=len(inputs.shape)
      if l==3:
        #hogy működjön több példára egyszerre
        #augmented=[spectogram_mask(i.numpy(),self.freq_times,self.F,self.time_times,self.T) for i in inputs]
        augmented=tf.vectorized_map(lambda x: spectogram_mask(x,self.freq_times,self.F,self.time_times,self.T),inputs)
      else:
        #hogy működjön egyetlen példára is
        augmented=tf.convert_to_tensor(spectogram_mask(inputs.numpy(),self.freq_times,self.F,self.time_times,self.T))

      return augmented

    if training is None:
      training = backend.learning_phase()

    output = control_flow_util.smart_cond(training, augmentation,
                                          lambda: inputs)
    output.set_shape(inputs.shape)
    return output

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
      'freq_times': self.freq_times,
      'F': self.F,
      'time_times': self.time_times,
      'T': self.T,
      'seed': self.seed
    }
    
    return dict(list(base_config.items()) + list(config.items()))

The error:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  PartialTensorShape: Incompatible shapes during merge: [18,101] vs. [10,101]
     [[]]
     [[ConstantFolding/assert_greater_equal/Assert/AssertGuard/switch_pred/_5_const_false/_55]]
  (1) Invalid argument:  PartialTensorShape: Incompatible shapes during merge: [18,101] vs. [10,101]
     [[]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_51113]

Function call stack:
train_function -> train_function



Aucun commentaire:

Enregistrer un commentaire