Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
MRCNN Particle Detection
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
AVT-FVT
public
MRCNN Particle Detection
Commits
2e017eda
Commit
2e017eda
authored
1 year ago
by
Stepan Sibirtsev
Browse files
Options
Downloads
Patches
Plain Diff
Upload New File
parent
3cb113fe
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
mrcnn/visualize.py
+540
-0
540 additions, 0 deletions
mrcnn/visualize.py
with
540 additions
and
0 deletions
mrcnn/visualize.py
0 → 100644
+
540
−
0
View file @
2e017eda
"""
MRCNN Particle Detection
Display and Visualization Functions.
The source code of
"
MRCNN Particle Detection
"
(https://git.rwth-aachen.de/avt-fvt/private/mrcnn-particle-detection)
is based on the source code of
"
Mask R-CNN
"
(https://github.com/matterport/Mask_RCNN).
The source code of
"
Mask R-CNN
"
is licensed under the MIT License (MIT).
Copyright (c) 2017 Matterport, Inc.
Written by Waleed Abdulla
All source code modifications to the source code of
"
Mask R-CNN
"
in
"
MRCNN Particle Detection
"
are licensed under the Eclipse Public License v2.0 (EPL 2.0).
Copyright (c) 2022-2023 Fluid Process Engineering (AVT.FVT), RWTH Aachen University
Edited by Stepan Sibirtsev, Mathias Neufang & Jakob Seiler
The coyprights and license terms are given in LICENSE.
Ideas and a small code snippets were adapted from these sources:
https://github.com/mat02/Mask_RCNN
"""
import
os
import
sys
import
random
import
itertools
import
colorsys
import
numpy
as
np
from
skimage.measure
import
find_contours
import
matplotlib.pyplot
as
plt
from
matplotlib
import
patches
,
lines
from
matplotlib.patches
import
Polygon
import
IPython.display
# Root directory of the project
ROOT_DIR
=
os
.
path
.
abspath
(
"
../
"
)
# Import Mask RCNN
sys
.
path
.
append
(
ROOT_DIR
)
# To find local version of the library
from
mrcnn
import
utils
############################################################
# Visualization
############################################################
def
display_images
(
images
,
titles
=
None
,
cols
=
4
,
cmap
=
None
,
norm
=
None
,
interpolation
=
None
):
"""
Display the given set of images, optionally with titles.
images: list or array of image tensors in HWC format.
titles: optional. A list of titles to display with each image.
cols: number of images per row
cmap: Optional. Color map to use. For example,
"
Blues
"
.
norm: Optional. A Normalize instance to map values to colors.
interpolation: Optional. Image interpolation to use for display.
"""
titles
=
titles
if
titles
is
not
None
else
[
""
]
*
len
(
images
)
rows
=
len
(
images
)
//
cols
+
1
plt
.
figure
(
figsize
=
(
14
,
14
*
rows
//
cols
))
i
=
1
for
image
,
title
in
zip
(
images
,
titles
):
plt
.
subplot
(
rows
,
cols
,
i
)
plt
.
title
(
title
,
fontsize
=
9
)
plt
.
axis
(
'
off
'
)
plt
.
imshow
(
image
.
astype
(
np
.
uint8
),
cmap
=
cmap
,
norm
=
norm
,
interpolation
=
interpolation
)
i
+=
1
plt
.
show
()
def
random_colors
(
N
,
bright
=
True
):
"""
Generate random colors.
To get visually distinct colors, generate them in HSV space then
convert to RGB.
"""
brightness
=
1.0
if
bright
else
0.7
hsv
=
[(
i
/
N
,
1
,
brightness
)
for
i
in
range
(
N
)]
colors
=
list
(
map
(
lambda
c
:
colorsys
.
hsv_to_rgb
(
*
c
),
hsv
))
random
.
shuffle
(
colors
)
return
colors
def
apply_mask
(
image
,
mask
,
color
,
alpha
=
0.5
):
"""
Apply the given mask to the image.
"""
for
c
in
range
(
3
):
image
[:,
:,
c
]
=
np
.
where
(
mask
==
1
,
image
[:,
:,
c
]
*
(
1
-
alpha
)
+
alpha
*
color
[
c
]
*
255
,
image
[:,
:,
c
])
return
image
def
display_instances
(
image
,
boxes
,
masks
,
class_ids
,
class_names
,
scores
=
None
,
title
=
None
,
figsize
=
(
16
,
16
),
ax
=
None
,
show_mask
=
True
,
show_bbox
=
True
,
colors
=
None
,
captions
=
None
,
save_img
=
False
,
save_dir
=
None
,
img_name
=
None
,
number_saved_images
=
None
,
counter_1
=
None
):
# Dennis: save-function implemented
"""
boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates.
masks: [height, width, num_instances]
class_ids: [num_instances]
class_names: list of class names of the dataset
scores: (optional) confidence scores for each box
title: (optional) Figure title
show_mask, show_bbox: To show masks and bounding boxes or not
figsize: (optional) the size of the image
colors: (optional) An array or colors to use with each object
captions: (optional) A list of strings to use as captions for each object
save_img: To save the predict image
save_dir: If save_img is True, the directory where you want to save the predict image
img_name: If save_img is True, the name of the predict image
"""
# Number of instances
N
=
boxes
.
shape
[
0
]
if
not
N
:
print
(
"
\n
*** No instances to display ***
\n
"
)
else
:
# distinction between with or without mask detection
if
show_bbox
and
show_mask
:
assert
boxes
.
shape
[
0
]
==
masks
.
shape
[
-
1
]
==
class_ids
.
shape
[
0
]
elif
show_bbox
and
not
show_mask
:
assert
boxes
.
shape
[
0
]
==
class_ids
.
shape
[
0
]
elif
not
show_bbox
and
show_mask
:
assert
masks
.
shape
[
-
1
]
==
class_ids
.
shape
[
0
]
# If no axis is passed, create one and automatically call show()
auto_show
=
False
if
not
ax
:
_
,
ax
=
plt
.
subplots
(
1
,
figsize
=
figsize
)
auto_show
=
True
# Generate random colors
colors
=
colors
or
random_colors
(
N
)
# Show area outside image boundaries.
#height, width = image.shape[:2]
#ax.set_ylim(height + 10, -10)
#ax.set_xlim(-10, width + 10)
ax
.
axis
(
'
off
'
)
#ax.set_title(title)
masked_image
=
image
.
astype
(
np
.
uint32
).
copy
()
for
i
in
range
(
N
):
color
=
colors
[
i
]
# Bounding box
if
not
np
.
any
(
boxes
[
i
]):
# Skip this instance. Has no bbox. Likely lost in image cropping.
continue
y1
,
x1
,
y2
,
x2
=
boxes
[
i
]
if
show_bbox
:
p
=
patches
.
Rectangle
((
x1
,
y1
),
x2
-
x1
,
y2
-
y1
,
linewidth
=
3
,
alpha
=
0.7
,
linestyle
=
"
dashed
"
,
edgecolor
=
color
,
facecolor
=
'
none
'
)
ax
.
add_patch
(
p
)
# Label
if
not
captions
:
class_id
=
class_ids
[
i
]
score
=
scores
[
i
]
if
scores
is
not
None
else
None
label
=
class_names
[
class_id
]
caption
=
"
{} {:.3f}
"
.
format
(
label
,
score
)
if
score
else
label
else
:
caption
=
str
(
captions
[
i
])
+
"
mm
"
ax
.
text
(
x1
,
y1
+
8
,
caption
,
color
=
'
w
'
,
size
=
11
,
backgroundcolor
=
"
none
"
)
# Mask
# Add mask detection
if
show_mask
:
mask
=
masks
[:,
:,
i
]
masked_image
=
apply_mask
(
masked_image
,
mask
,
color
)
# Mask Polygon
# Pad to ensure proper polygons for masks that touch image edges.
padded_mask
=
np
.
zeros
(
(
mask
.
shape
[
0
]
+
2
,
mask
.
shape
[
1
]
+
2
),
dtype
=
np
.
uint8
)
padded_mask
[
1
:
-
1
,
1
:
-
1
]
=
mask
contours
=
find_contours
(
padded_mask
,
0.5
)
for
verts
in
contours
:
# Subtract the padding and flip (y, x) to (x, y)
verts
=
np
.
fliplr
(
verts
)
-
1
p
=
Polygon
(
verts
,
facecolor
=
"
none
"
,
edgecolor
=
color
)
ax
.
add_patch
(
p
)
ax
.
imshow
(
masked_image
.
astype
(
np
.
uint8
))
if
save_img
:
#choose = round(number_saved_images * random.random()) # Dennis: image with detection save
if
(
counter_1
==
number_saved_images
):
plt
.
savefig
(
os
.
path
.
join
(
save_dir
,
img_name
+
'
.jpg
'
),
bbox_inches
=
'
tight
'
,
pad_inches
=
False
,
orientation
=
'
landscape
'
)
# Dennis: image with detection save
plt
.
close
()
# else: # Dennis: image with detection save
# continue # Dennis: image with detection save
# if auto_show:
# plt.show()
# plt.annotate('25, 50', xy=(25, 40), xycoords='data',
# xytext=(0.5, 0.5), textcoords='figure fraction',
# arrowprops=dict(arrowstyle="->"))
#if save_img: # Dennis: save function
# cv2.imwrite(os.path.join(save_dir, img_name), image) # Dennis: save function
def
display_differences
(
image
,
gt_box
,
gt_class_id
,
gt_mask
,
pred_box
,
pred_class_id
,
pred_score
,
pred_mask
,
class_names
,
title
=
""
,
ax
=
None
,
show_mask
=
True
,
show_box
=
True
,
iou_threshold
=
0.5
,
score_threshold
=
0.5
):
"""
Display ground truth and prediction instances on the same image.
"""
# Match predictions to ground truth
gt_match
,
pred_match
,
overlaps
=
utils
.
compute_matches
(
gt_box
,
gt_class_id
,
gt_mask
,
pred_box
,
pred_class_id
,
pred_score
,
pred_mask
,
iou_threshold
=
iou_threshold
,
score_threshold
=
score_threshold
)
# Ground truth = green. Predictions = red
colors
=
[(
0
,
1
,
0
,
.
8
)]
*
len
(
gt_match
)
\
+
[(
1
,
0
,
0
,
1
)]
*
len
(
pred_match
)
# Concatenate GT and predictions
class_ids
=
np
.
concatenate
([
gt_class_id
,
pred_class_id
])
scores
=
np
.
concatenate
([
np
.
zeros
([
len
(
gt_match
)]),
pred_score
])
boxes
=
np
.
concatenate
([
gt_box
,
pred_box
])
masks
=
np
.
concatenate
([
gt_mask
,
pred_mask
],
axis
=-
1
)
# Captions per instance show score/IoU
captions
=
[
""
for
m
in
gt_match
]
+
[
"
{:.2f} / {:.2f}
"
.
format
(
pred_score
[
i
],
(
overlaps
[
i
,
int
(
pred_match
[
i
])]
if
pred_match
[
i
]
>
-
1
else
overlaps
[
i
].
max
()))
for
i
in
range
(
len
(
pred_match
))]
# Set title if not provided
title
=
title
or
"
Ground Truth and Detections
\n
GT=green, pred=red, captions: score/IoU
"
# Display
display_instances
(
image
,
boxes
,
masks
,
class_ids
,
class_names
,
scores
,
ax
=
ax
,
show_bbox
=
show_box
,
show_mask
=
show_mask
,
colors
=
colors
,
captions
=
captions
,
title
=
title
)
def
draw_rois
(
image
,
rois
,
refined_rois
,
mask
,
class_ids
,
class_names
,
limit
=
10
):
"""
anchors: [n, (y1, x1, y2, x2)] list of anchors in image coordinates.
proposals: [n, 4] the same anchors but refined to fit objects better.
"""
masked_image
=
image
.
copy
()
# Pick random anchors in case there are too many.
ids
=
np
.
arange
(
rois
.
shape
[
0
],
dtype
=
np
.
int32
)
ids
=
np
.
random
.
choice
(
ids
,
limit
,
replace
=
False
)
if
ids
.
shape
[
0
]
>
limit
else
ids
fig
,
ax
=
plt
.
subplots
(
1
,
figsize
=
(
12
,
12
))
if
rois
.
shape
[
0
]
>
limit
:
plt
.
title
(
"
Showing {} random ROIs out of {}
"
.
format
(
len
(
ids
),
rois
.
shape
[
0
]))
else
:
plt
.
title
(
"
{} ROIs
"
.
format
(
len
(
ids
)))
# Show area outside image boundaries.
ax
.
set_ylim
(
image
.
shape
[
0
]
+
20
,
-
20
)
ax
.
set_xlim
(
-
50
,
image
.
shape
[
1
]
+
20
)
ax
.
axis
(
'
off
'
)
for
i
,
id
in
enumerate
(
ids
):
color
=
np
.
random
.
rand
(
3
)
class_id
=
class_ids
[
id
]
# ROI
y1
,
x1
,
y2
,
x2
=
rois
[
id
]
p
=
patches
.
Rectangle
((
x1
,
y1
),
x2
-
x1
,
y2
-
y1
,
linewidth
=
2
,
edgecolor
=
color
if
class_id
else
"
gray
"
,
facecolor
=
'
none
'
,
linestyle
=
"
dashed
"
)
ax
.
add_patch
(
p
)
# Refined ROI
if
class_id
:
ry1
,
rx1
,
ry2
,
rx2
=
refined_rois
[
id
]
p
=
patches
.
Rectangle
((
rx1
,
ry1
),
rx2
-
rx1
,
ry2
-
ry1
,
linewidth
=
2
,
edgecolor
=
color
,
facecolor
=
'
none
'
)
ax
.
add_patch
(
p
)
# Connect the top-left corners of the anchor and proposal for easy visualization
ax
.
add_line
(
lines
.
Line2D
([
x1
,
rx1
],
[
y1
,
ry1
],
color
=
color
))
# Label
label
=
class_names
[
class_id
]
ax
.
text
(
rx1
,
ry1
+
8
,
"
{}
"
.
format
(
label
),
color
=
'
w
'
,
size
=
11
,
backgroundcolor
=
"
none
"
)
# Mask
m
=
utils
.
unmold_mask
(
mask
[
id
],
rois
[
id
]
[:
4
].
astype
(
np
.
int32
),
image
.
shape
)
masked_image
=
apply_mask
(
masked_image
,
m
,
color
)
ax
.
imshow
(
masked_image
)
# Print stats
print
(
"
Positive ROIs:
"
,
class_ids
[
class_ids
>
0
].
shape
[
0
])
print
(
"
Negative ROIs:
"
,
class_ids
[
class_ids
==
0
].
shape
[
0
])
print
(
"
Positive Ratio: {:.2f}
"
.
format
(
class_ids
[
class_ids
>
0
].
shape
[
0
]
/
class_ids
.
shape
[
0
]))
# TODO: Replace with matplotlib equivalent?
def
draw_box
(
image
,
box
,
color
):
"""
Draw 3-pixel width bounding boxes on the given image array.
color: list of 3 int values for RGB.
"""
y1
,
x1
,
y2
,
x2
=
box
image
[
y1
:
y1
+
2
,
x1
:
x2
]
=
color
image
[
y2
:
y2
+
2
,
x1
:
x2
]
=
color
image
[
y1
:
y2
,
x1
:
x1
+
2
]
=
color
image
[
y1
:
y2
,
x2
:
x2
+
2
]
=
color
return
image
def
display_top_masks
(
image
,
mask
,
class_ids
,
class_names
,
limit
=
4
):
"""
Display the given image and the top few class masks.
"""
to_display
=
[]
titles
=
[]
to_display
.
append
(
image
)
titles
.
append
(
"
H x W={}x{}
"
.
format
(
image
.
shape
[
0
],
image
.
shape
[
1
]))
# Pick top prominent classes in this image
unique_class_ids
=
np
.
unique
(
class_ids
)
mask_area
=
[
np
.
sum
(
mask
[:,
:,
np
.
where
(
class_ids
==
i
)[
0
]])
for
i
in
unique_class_ids
]
top_ids
=
[
v
[
0
]
for
v
in
sorted
(
zip
(
unique_class_ids
,
mask_area
),
key
=
lambda
r
:
r
[
1
],
reverse
=
True
)
if
v
[
1
]
>
0
]
# Generate images and titles
for
i
in
range
(
limit
):
class_id
=
top_ids
[
i
]
if
i
<
len
(
top_ids
)
else
-
1
# Pull masks of instances belonging to the same class.
m
=
mask
[:,
:,
np
.
where
(
class_ids
==
class_id
)[
0
]]
m
=
np
.
sum
(
m
*
np
.
arange
(
1
,
m
.
shape
[
-
1
]
+
1
),
-
1
)
to_display
.
append
(
m
)
titles
.
append
(
class_names
[
class_id
]
if
class_id
!=
-
1
else
"
-
"
)
display_images
(
to_display
,
titles
=
titles
,
cols
=
limit
+
1
,
cmap
=
"
Blues_r
"
)
def
plot_precision_recall
(
AP
,
precisions
,
recalls
):
"""
Draw the precision-recall curve.
AP: Average precision at IoU >= 0.5
precisions: list of precision values
recalls: list of recall values
"""
# Plot the Precision-Recall curve
_
,
ax
=
plt
.
subplots
(
1
)
ax
.
set_title
(
"
Precision-Recall Curve. AP@50 = {:.3f}
"
.
format
(
AP
))
ax
.
set_ylim
(
0
,
1.1
)
ax
.
set_xlim
(
0
,
1.1
)
_
=
ax
.
plot
(
recalls
,
precisions
)
def
plot_overlaps
(
gt_class_ids
,
pred_class_ids
,
pred_scores
,
overlaps
,
class_names
,
threshold
=
0.5
):
"""
Draw a grid showing how ground truth objects are classified.
gt_class_ids: [N] int. Ground truth class IDs
pred_class_id: [N] int. Predicted class IDs
pred_scores: [N] float. The probability scores of predicted classes
overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
class_names: list of all class names in the dataset
threshold: Float. The prediction probability required to predict a class
"""
gt_class_ids
=
gt_class_ids
[
gt_class_ids
!=
0
]
pred_class_ids
=
pred_class_ids
[
pred_class_ids
!=
0
]
plt
.
figure
(
figsize
=
(
12
,
10
))
plt
.
imshow
(
overlaps
,
interpolation
=
'
nearest
'
,
cmap
=
plt
.
cm
.
Blues
)
plt
.
yticks
(
np
.
arange
(
len
(
pred_class_ids
)),
[
"
{} ({:.2f})
"
.
format
(
class_names
[
int
(
id
)],
pred_scores
[
i
])
for
i
,
id
in
enumerate
(
pred_class_ids
)])
plt
.
xticks
(
np
.
arange
(
len
(
gt_class_ids
)),
[
class_names
[
int
(
id
)]
for
id
in
gt_class_ids
],
rotation
=
90
)
thresh
=
overlaps
.
max
()
/
2.
for
i
,
j
in
itertools
.
product
(
range
(
overlaps
.
shape
[
0
]),
range
(
overlaps
.
shape
[
1
])):
text
=
""
if
overlaps
[
i
,
j
]
>
threshold
:
text
=
"
match
"
if
gt_class_ids
[
j
]
==
pred_class_ids
[
i
]
else
"
wrong
"
color
=
(
"
white
"
if
overlaps
[
i
,
j
]
>
thresh
else
"
black
"
if
overlaps
[
i
,
j
]
>
0
else
"
grey
"
)
plt
.
text
(
j
,
i
,
"
{:.3f}
\n
{}
"
.
format
(
overlaps
[
i
,
j
],
text
),
horizontalalignment
=
"
center
"
,
verticalalignment
=
"
center
"
,
fontsize
=
9
,
color
=
color
)
plt
.
tight_layout
()
plt
.
xlabel
(
"
Ground Truth
"
)
plt
.
ylabel
(
"
Predictions
"
)
def
draw_boxes
(
image
,
boxes
=
None
,
refined_boxes
=
None
,
masks
=
None
,
captions
=
None
,
visibilities
=
None
,
title
=
""
,
ax
=
None
):
"""
Draw bounding boxes and segmentation masks with different
customizations.
boxes: [N, (y1, x1, y2, x2, class_id)] in image coordinates.
refined_boxes: Like boxes, but draw with solid lines to show
that they
'
re the result of refining
'
boxes
'
.
masks: [N, height, width]
captions: List of N titles to display on each box
visibilities: (optional) List of values of 0, 1, or 2. Determine how
prominent each bounding box should be.
title: An optional title to show over the image
ax: (optional) Matplotlib axis to draw on.
"""
# Number of boxes
assert
boxes
is
not
None
or
refined_boxes
is
not
None
N
=
boxes
.
shape
[
0
]
if
boxes
is
not
None
else
refined_boxes
.
shape
[
0
]
# Matplotlib Axis
if
not
ax
:
_
,
ax
=
plt
.
subplots
(
1
,
figsize
=
(
12
,
12
))
# Generate random colors
colors
=
random_colors
(
N
)
# Show area outside image boundaries.
margin
=
image
.
shape
[
0
]
//
10
ax
.
set_ylim
(
image
.
shape
[
0
]
+
margin
,
-
margin
)
ax
.
set_xlim
(
-
margin
,
image
.
shape
[
1
]
+
margin
)
ax
.
axis
(
'
off
'
)
ax
.
set_title
(
title
)
masked_image
=
image
.
astype
(
np
.
uint32
).
copy
()
for
i
in
range
(
N
):
# Box visibility
visibility
=
visibilities
[
i
]
if
visibilities
is
not
None
else
1
if
visibility
==
0
:
color
=
"
gray
"
style
=
"
dotted
"
alpha
=
0.5
elif
visibility
==
1
:
color
=
colors
[
i
]
style
=
"
dotted
"
alpha
=
1
elif
visibility
==
2
:
color
=
colors
[
i
]
style
=
"
solid
"
alpha
=
1
# Boxes
if
boxes
is
not
None
:
if
not
np
.
any
(
boxes
[
i
]):
# Skip this instance. Has no bbox. Likely lost in cropping.
continue
y1
,
x1
,
y2
,
x2
=
boxes
[
i
]
p
=
patches
.
Rectangle
((
x1
,
y1
),
x2
-
x1
,
y2
-
y1
,
linewidth
=
2
,
alpha
=
alpha
,
linestyle
=
style
,
edgecolor
=
color
,
facecolor
=
'
none
'
)
ax
.
add_patch
(
p
)
# Refined boxes
if
refined_boxes
is
not
None
and
visibility
>
0
:
ry1
,
rx1
,
ry2
,
rx2
=
refined_boxes
[
i
].
astype
(
np
.
int32
)
p
=
patches
.
Rectangle
((
rx1
,
ry1
),
rx2
-
rx1
,
ry2
-
ry1
,
linewidth
=
2
,
edgecolor
=
color
,
facecolor
=
'
none
'
)
ax
.
add_patch
(
p
)
# Connect the top-left corners of the anchor and proposal
if
boxes
is
not
None
:
ax
.
add_line
(
lines
.
Line2D
([
x1
,
rx1
],
[
y1
,
ry1
],
color
=
color
))
# Captions
if
captions
is
not
None
:
caption
=
captions
[
i
]
# If there are refined boxes, display captions on them
if
refined_boxes
is
not
None
:
y1
,
x1
,
y2
,
x2
=
ry1
,
rx1
,
ry2
,
rx2
ax
.
text
(
x1
,
y1
,
caption
,
size
=
11
,
verticalalignment
=
'
top
'
,
color
=
'
w
'
,
backgroundcolor
=
"
none
"
,
bbox
=
{
'
facecolor
'
:
color
,
'
alpha
'
:
0.5
,
'
pad
'
:
2
,
'
edgecolor
'
:
'
none
'
})
# Masks
if
masks
is
not
None
:
mask
=
masks
[:,
:,
i
]
masked_image
=
apply_mask
(
masked_image
,
mask
,
color
)
# Mask Polygon
# Pad to ensure proper polygons for masks that touch image edges.
padded_mask
=
np
.
zeros
(
(
mask
.
shape
[
0
]
+
2
,
mask
.
shape
[
1
]
+
2
),
dtype
=
np
.
uint8
)
padded_mask
[
1
:
-
1
,
1
:
-
1
]
=
mask
contours
=
find_contours
(
padded_mask
,
0.5
)
for
verts
in
contours
:
# Subtract the padding and flip (y, x) to (x, y)
verts
=
np
.
fliplr
(
verts
)
-
1
p
=
Polygon
(
verts
,
facecolor
=
"
none
"
,
edgecolor
=
color
)
ax
.
add_patch
(
p
)
ax
.
imshow
(
masked_image
.
astype
(
np
.
uint8
))
def
display_table
(
table
):
"""
Display values in a table format.
table: an iterable of rows, and each row is an iterable of values.
"""
html
=
""
for
row
in
table
:
row_html
=
""
for
col
in
row
:
row_html
+=
"
<td>{:40}</td>
"
.
format
(
str
(
col
))
html
+=
"
<tr>
"
+
row_html
+
"
</tr>
"
html
=
"
<table>
"
+
html
+
"
</table>
"
IPython
.
display
.
display
(
IPython
.
display
.
HTML
(
html
))
def
display_weight_stats
(
model
):
"""
Scans all the weights in the model and returns a list of tuples
that contain stats about each weight.
"""
layers
=
model
.
get_trainable_layers
()
table
=
[[
"
WEIGHT NAME
"
,
"
SHAPE
"
,
"
MIN
"
,
"
MAX
"
,
"
STD
"
]]
for
l
in
layers
:
weight_values
=
l
.
get_weights
()
# list of Numpy arrays
weight_tensors
=
l
.
weights
# list of TF tensors
for
i
,
w
in
enumerate
(
weight_values
):
weight_name
=
weight_tensors
[
i
].
name
# Detect problematic layers. Exclude biases of conv layers.
alert
=
""
if
w
.
min
()
==
w
.
max
()
and
not
(
l
.
__class__
.
__name__
==
"
Conv2D
"
and
i
==
1
):
alert
+=
"
<span style=
'
color:red
'
>*** dead?</span>
"
if
np
.
abs
(
w
.
min
())
>
1000
or
np
.
abs
(
w
.
max
())
>
1000
:
alert
+=
"
<span style=
'
color:red
'
>*** Overflow?</span>
"
# Add row
table
.
append
([
weight_name
+
alert
,
str
(
w
.
shape
),
"
{:+9.4f}
"
.
format
(
w
.
min
()),
"
{:+10.4f}
"
.
format
(
w
.
max
()),
"
{:+9.4f}
"
.
format
(
w
.
std
()),
])
display_table
(
table
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment