Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
T
Tools
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
3pia
CMS Analyses
Tools
Commits
a18a0e81
Commit
a18a0e81
authored
4 years ago
by
jan.middendorf@rwth-aachen.de
Browse files
Options
Downloads
Plain Diff
Merge branch 'master' of git.rwth-aachen.de:3pia/cms_analyses/tools
parents
aa7bd212
bbb0bcc3
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
data.py
+41
-45
41 additions, 45 deletions
data.py
keras.py
+334
-91
334 additions, 91 deletions
keras.py
plotting.py
+122
-0
122 additions, 0 deletions
plotting.py
with
497 additions
and
136 deletions
data.py
+
41
−
45
View file @
a18a0e81
...
@@ -8,7 +8,7 @@ class SKDict(dict):
...
@@ -8,7 +8,7 @@ class SKDict(dict):
@staticmethod
@staticmethod
def
keyify
(
keyish
):
def
keyify
(
keyish
):
if
not
isinstance
(
keyish
,
(
tuple
,
list
,
set
,
frozenset
)):
if
not
isinstance
(
keyish
,
(
tuple
,
list
,
set
,
frozenset
)):
keyish
=
keyish
,
keyish
=
(
keyish
,
)
keyish
=
frozenset
(
keyish
)
keyish
=
frozenset
(
keyish
)
assert
not
any
(
isinstance
(
key
,
set
)
for
key
in
keyish
)
assert
not
any
(
isinstance
(
key
,
set
)
for
key
in
keyish
)
return
keyish
return
keyish
...
@@ -19,7 +19,7 @@ class SKDict(dict):
...
@@ -19,7 +19,7 @@ class SKDict(dict):
def
update
(
self
,
*
args
,
**
kwargs
):
def
update
(
self
,
*
args
,
**
kwargs
):
# assert 0 <= len(args) <= 1
# assert 0 <= len(args) <= 1
args
+=
kwargs
,
args
+=
(
kwargs
,
)
for
arg
in
args
:
for
arg
in
args
:
for
k
,
v
in
arg
.
items
():
for
k
,
v
in
arg
.
items
():
self
[
k
]
=
v
self
[
k
]
=
v
...
@@ -34,11 +34,7 @@ class SKDict(dict):
...
@@ -34,11 +34,7 @@ class SKDict(dict):
key
=
self
.
keyify
(
key
)
key
=
self
.
keyify
(
key
)
if
key
in
self
:
if
key
in
self
:
return
super
(
SKDict
,
self
).
__getitem__
(
key
)
return
super
(
SKDict
,
self
).
__getitem__
(
key
)
ret
=
self
.
__class__
({
ret
=
self
.
__class__
({
k
-
key
:
v
for
k
,
v
in
self
.
items
()
if
key
<=
k
})
k
-
key
:
v
for
k
,
v
in
self
.
items
()
if
key
<=
k
})
if
not
ret
:
if
not
ret
:
raise
KeyError
(
key
)
raise
KeyError
(
key
)
return
ret
return
ret
...
@@ -64,10 +60,7 @@ class SKDict(dict):
...
@@ -64,10 +60,7 @@ class SKDict(dict):
assert
all
(
isinstance
(
inst
,
cls
)
for
inst
in
insts
)
assert
all
(
isinstance
(
inst
,
cls
)
for
inst
in
insts
)
keys
=
set
()
keys
=
set
()
keys
.
update
(
*
(
inst
.
keys
()
for
inst
in
insts
))
keys
.
update
(
*
(
inst
.
keys
()
for
inst
in
insts
))
return
cls
({
return
cls
({
key
:
tuple
(
inst
.
get
(
key
)
for
inst
in
insts
)
for
key
in
keys
})
key
:
tuple
(
inst
.
get
(
key
)
for
inst
in
insts
)
for
key
in
keys
})
def
only
(
self
,
*
keys
):
def
only
(
self
,
*
keys
):
return
self
.
__class__
({
key
:
self
[
key
]
for
key
in
keys
})
return
self
.
__class__
({
key
:
self
[
key
]
for
key
in
keys
})
...
@@ -88,7 +81,7 @@ class SKDict(dict):
...
@@ -88,7 +81,7 @@ class SKDict(dict):
assert
len
(
keys
)
==
1
# bad depth
assert
len
(
keys
)
==
1
# bad depth
return
list
(
keys
)[
0
]
return
list
(
keys
)[
0
]
elif
ads
==
{
False
}:
elif
ads
==
{
False
}:
return
(),
return
(
(),
)
else
:
else
:
raise
RuntimeError
(
"
bad depth
"
)
raise
RuntimeError
(
"
bad depth
"
)
...
@@ -97,10 +90,7 @@ class SKDict(dict):
...
@@ -97,10 +90,7 @@ class SKDict(dict):
@property
@property
def
pretty
(
self
):
def
pretty
(
self
):
return
{
return
{
"
/
"
.
join
(
sorted
(
map
(
str
,
k
))):
v
for
k
,
v
in
self
.
items
()}
"
/
"
.
join
(
sorted
(
map
(
str
,
k
))):
v
for
k
,
v
in
self
.
items
()
}
class
GetNextSlice
(
object
):
class
GetNextSlice
(
object
):
...
@@ -113,7 +103,7 @@ class GetNextSlice(object):
...
@@ -113,7 +103,7 @@ class GetNextSlice(object):
if
self
.
curr
is
None
:
if
self
.
curr
is
None
:
self
.
curr
=
self
.
next
()
self
.
curr
=
self
.
next
()
self
.
pos
=
0
self
.
pos
=
0
sli
=
self
.
curr
[
self
.
pos
:
self
.
pos
+
num
]
sli
=
self
.
curr
[
self
.
pos
:
self
.
pos
+
num
]
self
.
pos
+=
num
self
.
pos
+=
num
if
len
(
sli
)
<
num
:
if
len
(
sli
)
<
num
:
del
self
.
curr
del
self
.
curr
...
@@ -135,39 +125,48 @@ class DSS(SKDict):
...
@@ -135,39 +125,48 @@ class DSS(SKDict):
assert
len
(
lens
)
==
1
assert
len
(
lens
)
==
1
return
lens
[
0
]
return
lens
[
0
]
@property
def
dtype
(
self
):
dtypes
=
list
(
set
(
val
.
dtype
for
val
in
self
.
values
()))
assert
len
(
dtypes
)
==
1
return
dtypes
[
0
]
@property
def
dims
(
self
):
dimss
=
list
(
set
(
val
.
ndim
for
val
in
self
.
values
()))
assert
len
(
dimss
)
==
1
return
dimss
[
0
]
@property
def
shape
(
self
):
shapes
=
list
(
set
(
val
.
shape
for
val
in
self
.
values
()))
if
len
(
shapes
)
>
1
:
assert
set
(
map
(
len
,
shapes
))
==
{
self
.
dims
}
return
tuple
(
s
[
0
]
if
len
(
s
)
==
1
else
None
for
s
in
map
(
list
,
map
(
set
,
zip
(
*
shapes
))))
return
shapes
[
0
]
def
fuse
(
self
,
*
keys
,
**
kwargs
):
def
fuse
(
self
,
*
keys
,
**
kwargs
):
op
=
kwargs
.
pop
(
"
op
"
,
np
.
concatenate
)
op
=
kwargs
.
pop
(
"
op
"
,
np
.
concatenate
)
assert
not
kwargs
assert
not
kwargs
return
self
.
zip
(
*
(
return
self
.
zip
(
*
(
self
[
self
.
keyify
(
key
)]
for
key
in
keys
)).
map
(
op
)
self
[
self
.
keyify
(
key
)]
for
key
in
keys
)).
map
(
op
)
def
split
(
self
,
thresh
,
right
=
False
,
rng
=
np
.
random
):
def
split
(
self
,
thresh
,
right
=
False
,
rng
=
np
.
random
):
if
isinstance
(
thresh
,
int
):
if
isinstance
(
thresh
,
int
):
thresh
=
np
.
linspace
(
0
,
1
,
num
=
thresh
+
1
)[
1
:
-
1
]
thresh
=
np
.
linspace
(
0
,
1
,
num
=
thresh
+
1
)[
1
:
-
1
]
if
isinstance
(
thresh
,
float
):
if
isinstance
(
thresh
,
float
):
thresh
=
thresh
,
thresh
=
(
thresh
,
)
thresh
=
np
.
array
(
thresh
)
thresh
=
np
.
array
(
thresh
)
assert
np
.
all
((
0
<
thresh
)
&
(
thresh
<
1
))
assert
np
.
all
((
0
<
thresh
)
&
(
thresh
<
1
))
idx
=
np
.
digitize
(
rng
.
uniform
(
size
=
self
.
blen
),
thresh
,
right
=
right
)
idx
=
np
.
digitize
(
rng
.
uniform
(
size
=
self
.
blen
),
thresh
,
right
=
right
)
return
tuple
(
return
tuple
(
self
.
map
(
itemgetter
(
idx
==
i
))
for
i
in
range
(
len
(
thresh
)
+
1
))
self
.
map
(
itemgetter
(
idx
==
i
))
for
i
in
range
(
len
(
thresh
)
+
1
)
)
def
shuffle
(
self
,
rng
=
np
.
random
):
def
shuffle
(
self
,
rng
=
np
.
random
):
return
self
.
map
(
itemgetter
(
rng
.
permutation
(
self
.
blen
)))
return
self
.
map
(
itemgetter
(
rng
.
permutation
(
self
.
blen
)))
def
gen_feed_dict
(
self
,
tensor2key
,
batch_size
=
1024
):
def
gen_feed_dict
(
self
,
tensor2key
,
batch_size
=
1024
):
for
sli
in
self
.
batch_slices
(
batch_size
):
for
sli
in
self
.
batch_slices
(
batch_size
):
buf
=
{
buf
=
{
key
:
self
[
key
][
sli
]
for
key
in
set
(
tensor2key
.
values
())}
key
:
self
[
key
][
sli
]
yield
{
tensor
:
buf
[
key
]
for
tensor
,
key
in
tensor2key
.
items
()}
for
key
in
set
(
tensor2key
.
values
())
}
yield
{
tensor
:
buf
[
key
]
for
tensor
,
key
in
tensor2key
.
items
()
}
def
batch_slices
(
self
,
batch_size
):
def
batch_slices
(
self
,
batch_size
):
for
i
in
range
(
0
,
self
.
blen
,
batch_size
):
for
i
in
range
(
0
,
self
.
blen
,
batch_size
):
...
@@ -191,9 +190,7 @@ class DSS(SKDict):
...
@@ -191,9 +190,7 @@ class DSS(SKDict):
getter
=
itemgetter
(
x
,
y
,
w
)
getter
=
itemgetter
(
x
,
y
,
w
)
train
,
valid
=
self
[
"
train
"
],
self
[
"
valid
"
]
train
,
valid
=
self
[
"
train
"
],
self
[
"
valid
"
]
return
dict
(
return
dict
(
zip
([
"
x
"
,
"
y
"
,
"
sample_weight
"
],
getter
(
train
)),
zip
([
"
x
"
,
"
y
"
,
"
sample_weight
"
],
getter
(
train
)),
validation_data
=
getter
(
valid
),
**
kwargs
validation_data
=
getter
(
valid
),
**
kwargs
)
)
def
balanced
(
self
,
*
keys
,
**
kwargs
):
def
balanced
(
self
,
*
keys
,
**
kwargs
):
...
@@ -207,18 +204,17 @@ class DSS(SKDict):
...
@@ -207,18 +204,17 @@ class DSS(SKDict):
s
=
np
.
sum
(
s
.
values
())
s
=
np
.
sum
(
s
.
values
())
sums
[
key
]
=
s
sums
[
key
]
=
s
ref
=
kref
(
sums
.
values
())
if
callable
(
kref
)
else
sums
[
kref
]
ref
=
kref
(
sums
.
values
())
if
callable
(
kref
)
else
sums
[
kref
]
return
self
.
__class__
({
return
self
.
__class__
({
k
:
self
[
k
].
map
(
lambda
x
:
x
*
(
ref
/
s
))
for
k
,
s
in
sums
.
items
()})
k
:
self
[
k
].
map
(
lambda
x
:
x
*
(
ref
/
s
))
for
k
,
s
in
sums
.
items
()
})
@classmethod
@classmethod
def
from_npy
(
cls
,
dir
,
sep
=
"
_
"
,
**
kwargs
):
def
from_npy
(
cls
,
dir
,
sep
=
"
_
"
,
**
kwargs
):
return
cls
({
return
cls
(
tuple
(
fn
[:
-
4
].
split
(
sep
)):
np
.
load
(
path
.
join
(
dir
,
fn
),
**
kwargs
)
{
for
fn
in
listdir
(
dir
)
tuple
(
fn
[:
-
4
].
split
(
sep
)):
np
.
load
(
path
.
join
(
dir
,
fn
),
**
kwargs
)
if
fn
.
endswith
(
"
.npy
"
)
for
fn
in
listdir
(
dir
)
})
if
fn
.
endswith
(
"
.npy
"
)
}
)
def
to_npy
(
self
,
dir
,
sep
=
"
_
"
,
**
kwargs
):
def
to_npy
(
self
,
dir
,
sep
=
"
_
"
,
**
kwargs
):
for
key
,
value
in
self
.
items
():
for
key
,
value
in
self
.
items
():
...
...
This diff is collapsed.
Click to expand it.
keras.py
+
334
−
91
View file @
a18a0e81
This diff is collapsed.
Click to expand it.
plotting.py
0 → 100644
+
122
−
0
View file @
a18a0e81
# -*- coding: utf-8 -*-
import
io
import
tensorflow
as
tf
import
itertools
import
numpy
as
np
from
matplotlib
import
pyplot
as
plt
from
sklearn.metrics
import
confusion_matrix
class
Quadrature
:
def
__init__
(
self
,
n
):
self
.
n
=
n
self
.
cols
=
np
.
ceil
(
np
.
sqrt
(
n
)).
astype
(
int
)
self
.
rows
=
np
.
ceil
(
n
/
self
.
cols
).
astype
(
int
)
def
lenghts
(
self
):
return
self
.
rows
,
self
.
cols
def
index
(
self
,
n
):
row
=
int
(
n
/
self
.
cols
)
col
=
n
-
row
*
self
.
cols
return
row
,
col
def
figure_confusion_matrix
(
truth
,
prediction
,
class_names
=
[
"
signal
"
,
"
background
"
],
sample_weight
=
None
,
normalize
=
"
true
"
,
**
kwargs
):
assert
len
(
class_names
)
==
truth
.
shape
[
-
1
]
==
prediction
.
shape
[
-
1
]
fig
,
ax
=
plt
.
subplots
()
cm
=
confusion_matrix
(
np
.
argmax
(
truth
,
axis
=-
1
),
np
.
argmax
(
prediction
,
axis
=-
1
),
sample_weight
=
sample_weight
,
normalize
=
normalize
,
)
cmap
=
"
plasma
"
if
normalize
==
"
true
"
else
"
viridis
"
im
=
ax
.
imshow
(
cm
,
interpolation
=
"
nearest
"
,
cmap
=
cmap
)
plt
.
title
(
"
Confusion matrix
"
)
fig
.
colorbar
(
im
,
ax
=
ax
)
tick_marks
=
np
.
arange
(
len
(
class_names
))
plt
.
xticks
(
tick_marks
,
class_names
,
rotation
=
45
)
plt
.
yticks
(
tick_marks
,
class_names
)
for
i
,
j
in
itertools
.
product
(
range
(
cm
.
shape
[
0
]),
range
(
cm
.
shape
[
1
])):
plt
.
text
(
j
,
i
,
np
.
around
(
cm
[
i
,
j
],
decimals
=
2
),
horizontalalignment
=
"
center
"
,
size
=
7
)
plt
.
ylabel
(
"
True label
"
)
plt
.
xlabel
(
"
Predicted label
"
)
fig
.
tight_layout
()
return
fig
def
figure_activations
(
activations
,
class_names
=
None
):
bins
=
np
.
linspace
(
0
,
1.0
,
10
)
n_b
,
n_p
=
activations
.
shape
fig
=
plt
.
figure
()
for
i
in
range
(
n_p
):
plt
.
hist
(
activations
[:,
i
],
bins
,
histtype
=
u
"
step
"
,
density
=
True
,
label
=
"
%i
"
%
i
if
class_names
is
None
else
class_names
[
i
],
)
plt
.
yscale
(
"
log
"
)
plt
.
legend
()
fig
.
tight_layout
()
return
fig
def
figure_node_activations
(
activations
,
truth
,
class_names
=
None
):
n_b
,
n_p
=
activations
.
shape
quad
=
Quadrature
(
n_p
)
rows
,
cols
=
quad
.
lenghts
()
fig
,
ax
=
plt
.
subplots
(
rows
,
cols
,
figsize
=
(
15
,
15
*
rows
/
cols
))
bins
=
np
.
linspace
(
0
,
1.0
,
10
)
process_activations
=
[]
for
process
in
range
(
n_p
):
process_activations
.
append
(
activations
[
truth
[:,
process
]].
swapaxes
(
0
,
1
))
for
node
in
range
(
n_p
):
for
process
in
range
(
n_p
):
ax
[
quad
.
index
(
node
)].
hist
(
process_activations
[
process
][
node
],
bins
,
histtype
=
u
"
step
"
,
density
=
True
,
label
=
"
%i
"
%
process
if
class_names
is
None
else
class_names
[
process
],
)
ax
[
quad
.
index
(
node
)].
set_yscale
(
"
log
"
)
ax
[
quad
.
index
(
cols
-
1
)].
legend
(
bbox_to_anchor
=
(
1.05
,
1.0
),
loc
=
"
upper left
"
)
fig
.
tight_layout
()
return
fig
def
figure_to_image
(
figure
):
"""
Converts the matplotlib plot specified by
'
figure
'
to a PNG image and
returns it. The supplied figure is closed and inaccessible after this call.
"""
# Save the plot to a PNG in memory.
buf
=
io
.
BytesIO
()
plt
.
savefig
(
buf
,
format
=
"
png
"
)
# Closing the figure prevents it from being displayed directly inside
# the notebook.
plt
.
close
(
figure
)
buf
.
seek
(
0
)
# Convert PNG buffer to TF image
image
=
tf
.
image
.
decode_png
(
buf
.
getvalue
(),
channels
=
4
)
# Add the batch dimension
image
=
tf
.
expand_dims
(
image
,
0
)
return
image
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment