Commit e1b9c61b authored by Marcel Rieger's avatar Marcel Rieger

Minor changes for TF2.

parent 71d1c382
Pipeline #118283 passed with stages
in 34 seconds
......@@ -16,7 +16,7 @@ unittests_py2:
stage: test
tags:
- docker
image: tensorflow/tensorflow:1.13.1
image: tensorflow/tensorflow:2.0.0a0
script:
- python -m unittest test
......@@ -24,6 +24,6 @@ unittests_py3:
stage: test
tags:
- docker
image: tensorflow/tensorflow:1.13.1-py3
image: tensorflow/tensorflow:2.0.0a0-py3
script:
- python -m unittest test
......@@ -119,6 +119,9 @@ class LBN(object):
else:
raise ValueError("invalid batch_norm, should be bool or list/tuple of two bools")
# the keras batch normalization layer
self.batch_norm = None
# particle weights and settings
self.particle_weights = particle_weights
self.abs_particle_weights = abs_particle_weights
......@@ -211,7 +214,7 @@ class LBN(object):
if self.features is None:
return None
return self.features.shape[-1].value
return int(self.features.shape[-1])
def register_feature(self, func=None, **kwargs):
"""
......@@ -301,8 +304,8 @@ class LBN(object):
self.inputs = inputs
# infer sizes
self.n_in = self.inputs.shape[1].value
self.n_dim = self.inputs.shape[2].value
self.n_in = int(self.inputs.shape[1])
self.n_dim = int(self.inputs.shape[2])
if self.n_dim != 4:
raise Exception("input dimension must be 4 to represent 4-vectors")
......@@ -510,16 +513,14 @@ class LBN(object):
def build_norm(self):
"""
Applies simple batch normalization with floating averages to the output features using
``tf.layers.batch_normalization``. Make sure to also run the operation returned by
``tf.get_collection(tf.GraphKeys.UPDATE_OPS)`` during each train step.
``tf.keras.layers.BatchNormalization``.
"""
self._norm_features = tf.layers.batch_normalization(
self.features,
self.batch_norm = tf.keras.layers.BatchNormalization(
axis=1,
training=self.is_training,
center=self.batch_norm_center,
scale=self.batch_norm_scale,
)
self._norm_features = self.batch_norm(self.features, training=self.is_training)
class FeatureFactoryBase(object):
......
......@@ -15,9 +15,6 @@ import tensorflow as tf
from lbn import LBN, FeatureFactory
# enable eager execution
tf.enable_eager_execution()
class TestCase(unittest.TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment