from matplotlib.backend_bases import MouseButton
from matplotlib.colors import BoundaryNorm
import matplotlib.pyplot as plt
import numpy as np
from skimage.segmentation import mark_boundaries
import sys
import time
import logging
def array_in_arraylist(array, arraylist):
    """Return whether ``array`` is equal to any element of ``arraylist``.

    Elements are compared with :func:`numpy.array_equal`, so two entries
    match only if they have the same shape and the same values.

    Parameters
    ----------
    array : array_like
        The array to look for.
    arraylist : iterable of array_like
        The candidates to compare against.

    Returns
    -------
    bool
        ``True`` if at least one candidate equals ``array``, ``False``
        otherwise (including when ``arraylist`` is empty).
    """
    return any(np.array_equal(array, candidate) for candidate in arraylist)
def unique_arrays(lst: list):
    """Return the elements of ``lst`` with duplicate arrays removed.

    Arrays are not hashable, so duplicates are detected by comparing each
    element against the ones already kept (via ``array_in_arraylist``).
    The first occurrence of every distinct array is kept and the original
    order is preserved.

    Parameters
    ----------
    lst : list
        List of array-like entries.

    Returns
    -------
    list
        New list holding the distinct entries of ``lst``.
    """
    used = []
    for candidate in lst:
        if not array_in_arraylist(candidate, used):
            used.append(candidate)
    return used
class SegmentLabeler:
    """Interactive matplotlib GUI for assigning labels to image segments.

    See :meth:`__init__` for a usage example, the GUI controls and the
    parameters.
    """

    # A press held longer than this many seconds is treated as a drag,
    # anything shorter as a single click.
    DRAG_THRESHOLD = 0.3

    def __init__(self,
                 img,
                 segmentation,
                 labeldict=None,
                 loglevel=logging.INFO):
        """Set up a labeler for an image and its segmentation.

        This class takes in an image, some corresponding segmentation
        (e.g. SLIC) and a label dictionary that can be used to label the
        image in an 'intuitive' GUI.

        Examples
        --------
        The usage of this labeling is fairly straightforward. Given an
        image, ``img``, and a segmentation, ``segs``, of said image, we
        instantiate the class and call the start labeling method.

        >>> from lwsspy.ml import SegmentLabeler
        >>> sl = SegmentLabeler(img, segs)
        >>> labeled_mask = sl.start_labeling()

        This will open two figures. One contains the image by itself and
        the second one contains the image and the outlines of the mask. We
        will use the first image for reference and the second image to
        actually label the image segments. The GUI has a few controls that
        can be used.

        ================= ==========================================
        Control           Action
        ================= ==========================================
        Mouse-left        Add label to segment
        Mouse-right       Remove label from segment
        Mouse-left-drag   Add labels to segments dragged over
        Mouse-right-drag  Remove labels from segments dragged over
        n                 Next label
        p                 Previous label
        d                 Delete previously labeled segment
        esc               Close figure and return the currently
                          selected mask
        ================= ==========================================

        A press that is held for more than ``DRAG_THRESHOLD`` seconds is
        interpreted as a drag.

        The currently selected mask will also be returned if any of the
        figures is closed.

        The selected mask can then be viewed via

        >>> import matplotlib.pyplot as plt
        >>> plt.imshow(labeled_mask, aspect='auto')

        Note that depending on the values you chose in the label dictionary
        you will have to create a colormap and norm that resembles the mask
        values.

        Parameters
        ----------
        img : ndarray [w x h x 3]
            Image
        segmentation : ndarray [w x h]
            mask that has a unique number for each segment such that it can
            be labeled.
        labeldict : dict, optional
            Dictionary of labels, must contain the 'none' keyword, which
            denotes the unlabeled value, and at least one other label,
            by default {'moho': 1, '410': 2, '660': 3, 'none': 9999}
        loglevel : logging.LOGLEVEL, optional
            loglevel, used to debug the event loop. Not necessary to be
            modified, by default logging.INFO

        Raises
        ------
        ValueError
            If ``labeldict`` has no 'none' entry or no other label.

        Notes
        -----

        :Authors:
            Lucas Sawade (lsawade@princeton.edu)

        :Last Modified:
            2021.07.02 00.00 (Lucas Sawade)

        """
        if labeldict is None:
            labeldict = {'moho': 1, '410': 2, '660': 3, 'none': 9999}

        self.img = img
        self.segmentation = segmentation
        # Keep a private copy so instances never share (or leak changes
        # into) the dictionary handed in by the caller.
        self.labeldict = dict(labeldict)
        self.safetyfirst()

        # Mask of label numbers, initialised to the 'unlabeled' value
        self.labeled = \
            self.labeldict['none'] * np.ones_like(self.segmentation)

        # Pick variables
        self.pickhistory = dict()
        self.picklabels = []
        self.picknumber = []

        for label, number in self.labeldict.items():
            if label != 'none':
                self.pickhistory[label] = []
                self.picklabels.append(label)
                self.picknumber.append(number)

        self.activelabel = 0
        self.nlabels = len(self.picklabels)

        # Mouse state; defined up front so that a release or motion event
        # arriving before any registered press cannot hit a missing
        # attribute.
        self.mouse_pressed = False
        self.clicktime = None
        self.coords = []

        # Logging
        self.loglevel = loglevel
        self.__setup_logger__()

    def safetyfirst(self):
        """Validate the label dictionary.

        Raises
        ------
        ValueError
            If the 'none' entry is missing, or if there is no label other
            than 'none' (there would be nothing to pick).
        """
        if 'none' not in self.labeldict:
            raise ValueError(
                "The label dictionary has to contain 'none' keywords.")

        if all(label == 'none' for label in self.labeldict):
            raise ValueError(
                "The label dictionary has to contain at least one label "
                "besides 'none'.")

    def __setup_logger__(self):
        """Create the 'SegmentationLabeler' logger writing to stdout."""
        # create logger
        self.logger = logging.getLogger('SegmentationLabeler')
        self.logger.setLevel(self.loglevel)

        # create console handler and set level to debug
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.DEBUG)

        # create formatter
        formatter = logging.Formatter(
            '[%(asctime)s] %(name)s | %(levelname)8s: '
            '%(message)s (%(filename)s:%(lineno)d)',
            datefmt='%m/%d/%Y %I:%M:%S %p')

        # add formatter to ch
        ch.setFormatter(formatter)

        # The logger is looked up by name and therefore shared between
        # instances: drop handlers of earlier instances so that messages
        # are not printed more than once.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
        self.logger.addHandler(ch)

    def plot_figure(self):
        """Create the reference figure and the figure used for labeling."""
        self.logger.debug('Plotting Image & Segmentation...')

        # Plain image for reference
        self.static_figure = plt.figure(figsize=(6, 6))
        self.static_ax = self.static_figure.add_subplot(111)
        self.static_ax.imshow(self.img, aspect='auto')

        # Image with the segment outlines; this is the one that is clicked
        self.segment_figure = plt.figure(figsize=(6, 6))
        self.segment_ax = self.segment_figure.add_subplot(111)
        self.segment_ax.imshow(
            mark_boundaries(self.img, self.segmentation, color=(0, 0, 0)),
            aspect='auto')

        self.plot_labeled_image()

    def _color_bounds(self):
        """Return the color-bin edges enclosing each label number.

        The edges lie halfway between neighbouring label numbers, with the
        outermost edges mirrored around the first and last number. The
        numbers are sorted and de-duplicated first because the edges have
        to increase monotonically, and a single label gets a bin of width
        one (there is no neighbour to take the spacing from).

        Returns
        -------
        list of float
            ``len(unique label numbers) + 1`` increasing bin edges.
        """
        numbers = np.unique(np.asarray(self.picknumber, dtype=float))

        if numbers.size > 1:
            half = np.diff(numbers) / 2
        else:
            half = np.array([0.5])

        bounds = np.concatenate((
            [numbers[0] - half[0]],
            numbers[:-1] + half[:numbers.size - 1],
            [numbers[-1] + half[-1]],
        ))
        return bounds.tolist()

    def plot_labeled_image(self):
        """Set up colormap and norm and draw the (still empty) label mask."""
        self.logger.debug('Plotting Labeled Image ...')

        # Color bounds based on the label numbers
        self.bounds = self._color_bounds()

        # Create adhoc cmap and norm; unlabeled (masked) pixels are
        # fully transparent
        self.cmap = plt.get_cmap('rainbow').copy()
        self.cmap.set_bad('lightgray', alpha=0.0)
        self.norm = BoundaryNorm(self.bounds, self.cmap.N)

        # Get mask
        self.__update_labeled_image__()

    def _set_title(self):
        """Show the currently active label in the title of the pick axes."""
        self.segment_ax.set_title(
            f"Picking {self.picklabels[self.activelabel]}")

    def start_labeling(self):
        """Open the GUI, block until it is closed and return the mask.

        Returns
        -------
        ndarray
            Array shaped like the segmentation that holds the label number
            of every labeled segment and the 'none' value elsewhere.
        """
        self.plot_figure()
        self._set_title()
        plt.draw()

        canvas = self.segment_figure.canvas
        canvas.mpl_connect('key_press_event', self.onkey)
        canvas.mpl_connect('button_press_event', self.onclick)
        canvas.mpl_connect('button_release_event', self.onrelease)
        canvas.mpl_connect('motion_notify_event', self.onmotion)

        # Closing connections. The ids are kept because each close handler
        # disconnects the handler of the other figure before closing it.
        self.cidsegment = self.segment_figure.canvas.mpl_connect(
            'close_event', self.onclose_segment)
        self.cidstatic = self.static_figure.canvas.mpl_connect(
            'close_event', self.onclose_static)

        plt.show(block=True)

        return self.labeled

    def _toolbar_mode(self):
        """Return the active toolbar mode, ``''`` if no tool is selected.

        Figures without a manager or toolbar are treated as having no tool
        selected.
        """
        manager = self.segment_figure.canvas.manager
        toolbar = getattr(manager, 'toolbar', None)
        if toolbar is None:
            return ''
        return toolbar.mode

    def _event_pixel(self, event):
        """Translate a mouse event into the ``(x, y)`` pixel under it.

        ``x`` is the column and ``y`` the row index into the segmentation.
        The data coordinates are rounded to the nearest integer because
        ``imshow`` (default extent) centres pixel ``i`` on coordinate
        ``i``.

        Returns
        -------
        tuple of int or None
            ``None`` if the event is outside the pick axes, carries no data
            coordinates, or points outside of the image (possible after
            zooming out or panning).
        """
        if event.inaxes is not self.segment_ax:
            return None
        if event.xdata is None or event.ydata is None:
            return None

        x = int(np.floor(event.xdata + 0.5))
        y = int(np.floor(event.ydata + 0.5))

        nrows, ncols = np.shape(self.segmentation)[:2]
        if not (0 <= y < nrows and 0 <= x < ncols):
            return None
        return x, y

    def _one_pixel_per_segment(self, pixels):
        """Keep only the first ``(x, y)`` pixel of every segment touched.

        A drag produces many motion events inside the same segment; one
        representative per segment is enough and avoids recomputing the
        full-image segment mask for each of them.
        """
        seen = set()
        representatives = []
        for x, y in pixels:
            value = self.segmentation[y, x]
            if value not in seen:
                seen.add(value)
                representatives.append((x, y))
        return representatives

    def onclick(self, event):
        """Button-press callback: start timing and recording a pick."""
        self.logger.debug('button press -- %s', event.button)
        self.clicktime = time.time()
        self.mouse_pressed = True
        self.coords = []

    def onmotion(self, event):
        """Motion callback: record the pixels passed while dragging."""
        # Only record while a button is held and no toolbar tool
        # (zoom/pan) is selected
        if not self.mouse_pressed or self._toolbar_mode() != '':
            return

        # The cursor may leave the axes or the image during a drag; such
        # events carry no usable pixel and are skipped.
        pixel = self._event_pixel(event)
        if pixel is None:
            return

        self.logger.debug('Pick location: (%s, %s)', *pixel)

        # Append coordinates
        self.coords.append(pixel)

    def onrelease(self, event):
        """Button-release callback: apply the finished click or drag.

        The left button labels and the right button unlabels. A short
        click acts on the segment under the cursor; a drag acts on every
        segment passed over, including the one the button is released on.
        """
        self.mouse_pressed = False
        self.logger.debug('button_release_event -- %s', event.button)

        # Return if pick not in axes
        if event.inaxes != self.segment_ax:
            return

        # Return if pick is happening while toolbar active
        state = self._toolbar_mode()
        if state != '':
            self.logger.debug('%s selected.', state)
            return
        self.logger.debug('No figure tool selected.')

        # Release without a registered press: nothing to apply
        if self.clicktime is None:
            return
        dragged = time.time() - self.clicktime > self.DRAG_THRESHOLD

        # Get pixel location of the release (None if outside the image)
        pixel = self._event_pixel(event)
        if pixel is not None:
            self.logger.debug('Pick location: (%s, %s)', *pixel)

        if dragged:
            path = list(self.coords)
            if pixel is not None:
                path.append(pixel)
            path = self._one_pixel_per_segment(path)

        if event.button is MouseButton.LEFT:
            if dragged:
                # Label many coordinates
                self.logger.debug("Left Mouse button dragged")
                for x, y in path:
                    self.__update_labeled__(y, x)
                self.__update_labeled_image__()
            elif pixel is not None:
                # Label single coordinate
                self.logger.debug("Left Mouse button pressed")
                self.__label_segment__(*pixel)

        elif event.button is MouseButton.RIGHT:
            if dragged:
                # Remove many coordinates
                self.logger.debug("Right Mouse button dragged")
                for x, y in path:
                    self.__reset_segment__(x, y)
                self.__update_labeled_image__()
            elif pixel is not None:
                # Remove single coordinates
                self.logger.debug("Right Mouse button pressed")
                self.__remove_segment__(*pixel)

        # Remove duplicates
        active = self.picklabels[self.activelabel]
        self.pickhistory[active] = unique_arrays(self.pickhistory[active])

    def __label_segment__(self, x, y):
        """Label the segment at column ``x``, row ``y`` and redraw."""
        # Update labeled and then update the image plotted on top
        self.__update_labeled__(y, x)
        self.__update_labeled_image__()

    def __remove_segment__(self, x, y):
        """Unlabel the segment at column ``x``, row ``y`` and redraw."""
        # Update labeled and the update image
        self.__reset_segment__(x, y)
        self.__update_labeled_image__()

    def __reset_segment__(self, x, y):
        """Reset the segment at column ``x``, row ``y`` to 'none'.

        The matching entry, if any, is removed from the history of the
        active label only.
        """
        # Get value of segmented image
        val = self.segmentation[y, x]

        # Find where everything is that is in the segment. Exact
        # comparison: a tolerance-based one also matches neighbouring ids
        # once the ids get large.
        pos = np.where(self.segmentation == val)

        # Remove the entry from the history
        history = self.pickhistory[self.picklabels[self.activelabel]]
        for index, recorded in enumerate(history):
            if np.array_equal(recorded, pos):
                history.pop(index)
                break

        # Reset the labeled array
        self.labeled[pos] = self.labeldict['none']

    def __update_labeled__(self, x, y):
        """Assign the active label to the segment containing a pixel.

        Note the argument order: the segmentation is indexed as
        ``segmentation[x, y]``, i.e. ``x`` is the first (row) index and
        ``y`` the second (column) index. The callers in this class pass
        ``(row, column)`` accordingly.
        """
        # Get value of segmented image
        val = self.segmentation[x, y]

        # Find where everything is that is in the segment (exact match,
        # see __reset_segment__)
        pos = np.where(self.segmentation == val)

        # Add pos to history
        self.pickhistory[self.picklabels[self.activelabel]].append(pos)

        # Set Labeled to active label number
        self.labeled[pos] = self.picknumber[self.activelabel]

    def onkey(self, event):
        """Key-press callback dispatching the keyboard controls."""
        self.logger.debug('key press -- %s', event.key)

        actions = {
            'n': self.__next_label__,
            'p': self.__previous_label__,
            'd': self.__remove_previous__,
            'escape': self.__stop_labeling__,
        }
        action = actions.get(event.key)
        if action is not None:
            action()

    def __next_label__(self):
        """Activate the next label; stays on the last one at the end."""
        self.activelabel = min(self.activelabel + 1, self.nlabels - 1)
        self._set_title()
        self.segment_figure.canvas.draw()

    def __previous_label__(self):
        """Activate the previous label; stays on the first one at 0."""
        self.activelabel = max(self.activelabel - 1, 0)
        self._set_title()
        self.segment_figure.canvas.draw()

    def __remove_previous__(self):
        """Undo the most recent pick of the active label, if there is one."""
        active = self.picklabels[self.activelabel]
        self.logger.debug(
            'Remove previous selection for label -- %s', active)

        history = self.pickhistory[active]
        if not history:
            return

        # Pop the latest position and reset it to the masked value
        pos = history.pop()
        self.labeled[pos] = self.labeldict['none']

        # Update the labeled image
        self.__update_labeled_image__()

    def __update_labeled_image__(self):
        """Redraw the label overlay from the current ``labeled`` array."""
        self.logger.debug('Updating Labeled Image ...')

        # Hide everything that is still unlabeled
        self.labeled_m = np.ma.masked_values(
            self.labeled, self.labeldict['none']
        )

        # The overlay artist exists only while the figures are open (it is
        # deleted on stop), so it is created on first use.
        if hasattr(self, 'labeled_img'):
            self.labeled_img.set_data(self.labeled_m)
        else:
            self.labeled_img = self.segment_ax.imshow(
                self.labeled_m, cmap=self.cmap, norm=self.norm,
                aspect='auto', alpha=0.5
            )

        self.segment_figure.canvas.draw()

    def onclose_segment(self, event):
        """Close callback of the pick figure: shut everything down."""
        self.logger.debug("Closing Segmentation Figure")
        # Disconnect the other figure's close handler before closing it
        self.static_figure.canvas.mpl_disconnect(self.cidstatic)
        self.__stop_labeling__()

    def onclose_static(self, event):
        """Close callback of the reference figure: shut everything down."""
        self.logger.debug("Closing Static Figure")
        # Disconnect the other figure's close handler before closing it
        self.segment_figure.canvas.mpl_disconnect(self.cidsegment)
        self.__stop_labeling__()

    def __stop_labeling__(self):
        """Close all open figures so that ``start_labeling`` returns."""
        self.logger.debug("Stop Labeling & quitting the program")

        for num in plt.get_fignums():
            plt.close(num)

        # Removed the plotted image so that it can be plotted again
        if hasattr(self, 'labeled_img'):
            del self.labeled_img

        # NOTE: the log handler is intentionally left installed, so the
        # instance keeps logging if labeling is started again.