TadGAN migration to TensorFlow 2.0
1. Deprecation of the `_Merge` layer
When generating the interpolated signal, we inherit from `_Merge`:
```python
from keras import backend as K
from keras.layers.merge import _Merge


class RandomWeightedAverage(_Merge):
    def _merge_function(self, inputs):
        alpha = K.random_uniform((64, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
```
which we can move into a subclass of `Layer` to produce the same result:
```python
from keras import backend as K
from keras.layers import Layer


class RandomWeightedAverage(Layer):
    def call(self, inputs, **kwargs):
        alpha = K.random_uniform((64, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
```
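As a quick sanity check, the new subclass can be called like any other Keras layer. The shapes below are illustrative only:

```python
import tensorflow as tf

real = tf.random.normal((64, 100, 1))
fake = tf.random.normal((64, 100, 1))

# Produces a random convex combination of the two inputs, per sample.
interpolated = RandomWeightedAverage()([real, fake])
print(interpolated.shape)  # (64, 100, 1)
```

Note that the batch size of 64 is hardcoded in `alpha`'s shape, so the layer only works for batches of exactly that size.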
2. Eager execution
The current code does not support TensorFlow eager execution. Under eager execution, the gradient penalty evaluates to `None` when computing `K.gradients(y_pred, averaged_samples)[0]`, which in turn causes the rest of the loss calculation to fail.
To follow the new style, we should use `GradientTape` in the training step of the model. This will require some fundamental changes to how TadGAN is currently compiled and trained.
```python
def train_step(self, X):
    for _ in range(critic_iterations):
        # generate random latent variables
        with tf.GradientTape() as tape:
            ...  # calculate the critics' losses
        ...  # update the critics' gradients
    with tf.GradientTape() as tape:
        ...  # calculate the generator/encoder loss
    ...  # update the generator/encoder gradients
```
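For concreteness, here is a minimal sketch of what one critic update could look like with `GradientTape`, including the gradient penalty computed via a nested tape. Names such as `critic_x`, `generator`, `latent_dim`, `gp_weight`, and `critic_x_optimizer` are illustrative assumptions, not TadGAN's actual attributes:

```python
import tensorflow as tf


def critic_x_train_step(self, x):
    # Sketch only: critic_x, generator, latent_dim, gp_weight, and
    # critic_x_optimizer are assumed attributes, not TadGAN's real API.
    batch_size = tf.shape(x)[0]
    z = tf.random.normal((batch_size, self.latent_dim, 1))

    with tf.GradientTape() as tape:
        x_fake = self.generator(z, training=True)
        fake_score = self.critic_x(x_fake, training=True)
        real_score = self.critic_x(x, training=True)

        # Interpolate between real and fake samples for the gradient penalty.
        alpha = tf.random.uniform((batch_size, 1, 1))
        interpolated = alpha * x + (1.0 - alpha) * x_fake
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            interp_score = self.critic_x(interpolated, training=True)
        grads = gp_tape.gradient(interp_score, interpolated)
        grad_norm = tf.norm(tf.reshape(grads, (batch_size, -1)), axis=1)
        gradient_penalty = tf.reduce_mean((grad_norm - 1.0) ** 2)

        # Wasserstein critic loss plus the gradient penalty term.
        loss = (tf.reduce_mean(fake_score) - tf.reduce_mean(real_score)
                + self.gp_weight * gradient_penalty)

    gradients = tape.gradient(loss, self.critic_x.trainable_variables)
    self.critic_x_optimizer.apply_gradients(
        zip(gradients, self.critic_x.trainable_variables))
    return loss
```

The same pattern would apply to `critic_z` and to the encoder/generator step, with all of them living inside a `tf.keras.Model` subclass's `train_step`.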
A temporary solution is to disable eager execution:
```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
```
Update

Eager execution is required to run on GPU, so the temporary solution only works on CPU. It is better to continue with the original proposal, even though it requires a lot of code refactoring.
3. Cycle in the graph
While the model still works correctly in TensorFlow 1, it generates the following warning:

```
topological sort failed with message: The graph couldn't be sorted in topological order
```

This is caused by a cycle in the directed graph, which we should eliminate. In fact, I believe this cycle causes an error in TensorFlow 2. Even if we make all the changes suggested above, we would still encounter an error within `critic_z_model` and `encoder_generator_model` of the following kind (this one specifically occurs when calling `critic_z_model.train_on_batch(..)`):
```
InvalidArgumentError: Node 'training_2/Adam/gradients/gradients/loss_1/functional_7_2_loss/gradients/functional_7_2/sequential_3/dropout_4/cond_grad/StatelessIf_grad/StatelessIf': Connecting to invalid output 1 of source node loss_1/functional_7_2_loss/gradients/functional_7_2/sequential_3/dropout_4/cond_grad/StatelessIf which has 1 outputs. Try using tf.compat.v1.experimental.output_all_intermediates(True).
```
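The error message itself points at a possible workaround. A minimal sketch of applying it alongside the eager-execution workaround (this silences the symptom but does not remove the cycle from the graph):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Suggested by the error message: force v1 control-flow ops (such as the
# dropout layer's StatelessIf) to output all intermediate tensors so the
# gradient graph can connect to them.
tf.compat.v1.experimental.output_all_intermediates(True)
```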