From e1b9c61bcade5f33a4e041db7fb49f16eddeef31 Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Sat, 13 Apr 2019 15:09:52 +0200 Subject: [PATCH] Minor changes for TF2. --- .gitlab-ci.yml | 4 ++-- lbn.py | 17 +++++++++-------- test.py | 3 --- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index db54f5f..be10a58 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -16,7 +16,7 @@ unittests_py2: stage: test tags: - docker - image: tensorflow/tensorflow:1.13.1 + image: tensorflow/tensorflow:2.0.0a0 script: - python -m unittest test @@ -24,6 +24,6 @@ unittests_py3: stage: test tags: - docker - image: tensorflow/tensorflow:1.13.1-py3 + image: tensorflow/tensorflow:2.0.0a0-py3 script: - python -m unittest test diff --git a/lbn.py b/lbn.py index 11b07f4..cfdde97 100644 --- a/lbn.py +++ b/lbn.py @@ -119,6 +119,9 @@ class LBN(object): else: raise ValueError("invalid batch_norm, should be bool or list/tuple of two bools") + # the keras batch normalization layer + self.batch_norm = None + # particle weights and settings self.particle_weights = particle_weights self.abs_particle_weights = abs_particle_weights @@ -211,7 +214,7 @@ class LBN(object): if self.features is None: return None - return self.features.shape[-1].value + return int(self.features.shape[-1]) def register_feature(self, func=None, **kwargs): """ @@ -301,8 +304,8 @@ class LBN(object): self.inputs = inputs # infer sizes - self.n_in = self.inputs.shape[1].value - self.n_dim = self.inputs.shape[2].value + self.n_in = int(self.inputs.shape[1]) + self.n_dim = int(self.inputs.shape[2]) if self.n_dim != 4: raise Exception("input dimension must be 4 to represent 4-vectors") @@ -510,16 +513,14 @@ class LBN(object): def build_norm(self): """ Applies simple batch normalization with floating averages to the output features using - ``tf.layers.batch_normalization``. Make sure to also run the operation returned by - ``tf.get_collection(tf.GraphKeys.UPDATE_OPS)`` during each train step. + ``tf.keras.layers.BatchNormalization``. """ - self._norm_features = tf.layers.batch_normalization( - self.features, + self.batch_norm = tf.keras.layers.BatchNormalization( axis=1, - training=self.is_training, center=self.batch_norm_center, scale=self.batch_norm_scale, ) + self._norm_features = self.batch_norm(self.features, training=self.is_training) class FeatureFactoryBase(object): diff --git a/test.py b/test.py index d40c75b..8de88c8 100644 --- a/test.py +++ b/test.py @@ -15,9 +15,6 @@ import tensorflow as tf from lbn import LBN, FeatureFactory -# enable eager execution -tf.enable_eager_execution() - class TestCase(unittest.TestCase): -- GitLab