Commit c57cdea9 authored by Marcel Rieger's avatar Marcel Rieger

Minor change to test_post_build_attributes unit test.

parent a5d84159
Pipeline #106927 passed with stages
in 34 seconds
......@@ -765,7 +765,7 @@ class FeatureFactory(FeatureFactoryBase):
def tf_non_zero(t, epsilon):
Ensures that all zeros in a tensor are replaced with epsilon.
Ensures that all zeros in a tensor *t* are replaced by *epsilon*.
# use combination of abs and sign instead of a where op
return t + (1 - tf.abs(tf.sign(t))) * epsilon
......@@ -104,9 +104,6 @@ class TestCase(unittest.TestCase):
self.assertIsNotNone(getattr(lbn, attr))
def test_post_build_attributes(self):
lbn = LBN(10, boost_mode=LBN.PAIRS, is_training=True)
lbn(self.vectors_t, features=self.feature_set).numpy()
attrs = [
"particle_weights", "abs_particle_weights", "clip_particle_weights",
"restframe_weights", "abs_restframe_weights", "clip_restframe_weights", "n_in", "n_dim",
......@@ -116,6 +113,12 @@ class TestCase(unittest.TestCase):
"restframes", "Lambda", "boosted_particles", "_raw_features", "_norm_features",
lbn = LBN(10, boost_mode=LBN.PAIRS, is_training=True)
for attr in attrs:
self.assertIn(getattr(lbn, attr), (None, True, False))
lbn(self.vectors_t, features=self.feature_set).numpy()
for attr in attrs:
self.assertIsNotNone(getattr(lbn, attr), None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment