Source code for imagedata.viewer

# import numbers
import copy
import logging
from typing import Union

import matplotlib.colors
import matplotlib.colorbar
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
from matplotlib.widgets import PolygonSelector
# from matplotlib.widgets import Slider
from matplotlib.path import Path as MplPath
import mpl_toolkits.axes_grid1
import numpy as np
from datetime import datetime

from .series import Series

POSITIVE_EPS = 1e-3
logger = logging.getLogger(__name__)


[docs] class Viewer(object): """Viewer -- a graphical tool to display and interact with Series objects. Args: images (imagedata.Series or list): Series object or list of Series objects to view. fig (Figure): matplotlib.plt.figure if already exist (optional). ax (Axes): matplotlib axis if already exist (optional). follow (bool): Copy ROI to next tag. Default: False. colormap (str): Colour map for display. Default: Greys_r. norm (str or matplotlib.colors.Normalize): Normalization method. Either linear/log, or the `.Normalize` instance used to scale scalar data to the [0, 1] range before mapping to colors using colormap. colorbar (bool): Display colorbar with image. Default: None: determine colorbar based on colormap and norm. range before mapping to colors using colormap. window (number): Window width of signal intensities. Default: DICOM Window Width. level (number): Window level of signal intensities. Default: DICOM Window Center. link (bool): Whether scrolling is linked between displayed objects. Default: False. onselect (function): call function when roi changes. Default: None. When a polygon is completed or modified after completion, the *onselect* function is called and passed idx, tag and a list of the vertices as ``(xdata, ydata)`` tuples. """ def __init__(self, images, fig=None, ax=None, follow=False, colormap='Greys_r', norm='linear', colorbar=None, window=None, level=None, link=False, onselect=None): self.fig = fig self.ax = ax if self.ax is None: self.ax = default_layout(fig, len(images)) self.im = {} if isinstance(norm, str): if norm == 'linear': norm = matplotlib.colors.Normalize elif norm == 'log': norm = matplotlib.colors.LogNorm # elif norm == 'centered': # norm = matplotlib.colors.CenteredNorm else: raise ValueError('Unknown normalization function: {}'.format(norm)) if colorbar is None: colorbar = colormap != 'Greys_r' or\ (norm is not None and norm != matplotlib.colors.Normalize) for i, im in enumerate(images): self.im[i] = build_info(im, colormap, norm, colorbar, window, level) self.follow = follow self.link = link self.cidenter = None self.cidleave = None self.cidscroll = None self.callback_quit = None self.vertices = None # The polygon vertices, as a dictionary of tags of (x,y) self.poly = None self.paste_buffer = None self.callback_onselect = onselect self.viewport = {} self.viewport_idx = None self.set_default_viewport(self.ax) # Set wanted viewport self.update() # Update view to achieve wanted viewport def __repr__(self): return object.__repr__(self) def __str__(self): return "{" + "{0:s} images".format(len(self.im)) + "}"
[docs] def set_default_viewport(self, axes): """View as many Series as there are axes""" try: if len(axes.shape) == 2: rows, columns = axes.shape elif len(axes.shape) == 1: columns = axes.shape[0] rows = 1 else: raise ValueError('Cannot set default viewport') except AttributeError: rows = columns = 1 self.rows = rows self.columns = columns # Setup initial view self.viewport_set(0)
def update(self): # For each viewport for vp_idx in self.viewport: vp = self.viewport[vp_idx] if vp is None: # Clear ax continue if vp['next'] != vp['present']: # We want to show another image in this viewport vp['ax'].cla() if vp['next'] in self.im: vp['h'] = self.show(vp['ax'], self.im[vp['next']]) vp['present'] = vp['next'] else: raise IndexError("Series {} should be viewed, but does not exist".format( vp['next'] )) elif vp['next'] is None: vp['ax'].cla() vp['ax'].set_axis_off() # Update present image in viewport try: im = self.im[vp['present']] except KeyError: continue if not im['modified']: continue if im['tag_axis'] is not None: # 4D viewer vp['h'].set_data(im['im'][im['tag'], im['idx'], ...]) if im['slider'] is not None: im['slider'].valtext.set_text(pretty_tag_value(im)) elif im['slice_axis'] is not None: # 3D viewer vp['h'].set_data(im['im'][im['idx'], ...]) vp['h'].set_clim(vmin=im['vmin'], vmax=im['vmax']) # im['ax'].set_ylabel('Slice {}'.format(self.im['idx'])) # Lower right text if im['lower_right_text'] is not None and im['lower_right_data'] != (im['tag'],): fmt = '' if im['tag_axis'] is not None: fmt = '{}[{}]: {}'.format(im['input_order'], im['tag'], pretty_tag_value(im)) im['lower_right_text'].txt.set_text(fmt) im['lower_right_data'] = (im['tag'],) # Lower left text if im['color']: fmt = 'SL: {0:d}' if im['lower_left_text'] is not None and im['lower_left_data'] != im['idx']: im['lower_left_text'].txt.set_text(fmt.format(im['idx'])) im['lower_left_data'] = im['idx'] else: fmt, window, level = pretty_window_level(im) if im['lower_left_text'] is not None: if im['lower_left_data'] != (window, level, im['idx']): im['lower_left_text'].txt.set_text(fmt.format(im['idx'], window, level)) im['lower_left_data'] = (window, level, im['idx']) if self.callback_onselect is not None: try: vertices = self.vertices[im['idx']] except (KeyError, TypeError): vertices = None self.callback_onselect(im['idx'], im['tag'], vertices) try: vp['ax'].axes.figure.canvas.draw() except ValueError: pass im['modified'] = False def show(self, ax, im): if im is None: return None if im['slice_axis'] is None: # 2D viewer h = ax.imshow(im['im'], cmap=im['colormap'], norm=im['norm']) elif im['tag_axis'] is None: # 3D viewer h = ax.imshow(im['im'][im['idx'], ...], cmap=im['colormap'], norm=im['norm']) else: # 4D viewer h = ax.imshow(im['im'][im['tag'], im['idx'], ...], cmap=im['colormap'], norm=im['norm']) # Lower right text fmt = '{}[{}]: {}'.format(im['input_order'], im['tag'], pretty_tag_value(im)) im['lower_right_data'] = (im['tag'],) im['lower_right_text'] = AnchoredText(fmt, prop=dict(size=6, color='white', backgroundcolor='black'), frameon=False, loc='lower right' ) artist = ax.add_artist(im['lower_right_text']) artist.set_visible(im['show_text']) im['artists'].append(artist) # Update lower left text if im['color']: fmt = 'SL: {0:d}' im['lower_left_data'] = (im['idx']) im['lower_left_text'] = AnchoredText(fmt.format(im['idx']), prop=dict(size=6, color='white', backgroundcolor='black'), frameon=False, loc='lower left' ) else: fmt, window, level = pretty_window_level(im) im['lower_left_data'] = (window, level, im['idx']) im['lower_left_text'] = AnchoredText(fmt.format(im['idx'], window, level), prop=dict(size=6, color='white', backgroundcolor='black'), frameon=False, loc='lower left' ) artist = ax.add_artist(im['lower_left_text']) artist.set_visible(im['show_text']) im['artists'].append(artist) # Update upper left text fmt = self.upper_left_text(im['im']) im['upper_left_text'] = AnchoredText(fmt, prop=dict(size=6, color='white', backgroundcolor='black'), frameon=False, loc='upper left' ) artist = ax.add_artist(im['upper_left_text']) artist.set_visible(im['show_text']) im['artists'].append(artist) # Update upper right text fmt = self.upper_right_text(im['im']) im['upper_right_text'] = AnchoredText(fmt, prop=dict(size=6, color='white', backgroundcolor='black'), frameon=False, loc='upper right' ) artist = ax.add_artist(im['upper_right_text']) artist.set_visible(im['show_text']) im['artists'].append(artist) im['modified'] = True if im['colorbar']: divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax) cax = divider.append_axes('right', size='5%', pad=0.05) # Create fake pcolormesh to create colorbar matching im['colorbar'] and im['norm'] _ = plt.colorbar( cax.pcolormesh( np.array([[im['norm'].vmin, im['norm'].vmax]]), visible=False, cmap=im['colorbar'], vmin=im['norm'].vmin, vmax=im['norm'].vmax ), label=im['colormap_label'], cax=cax ) ax.set_axis_off() # if im['slices'] == im2['slices']: # plt.subplots_adjust(bottom=0.1) # self.rax = plt.axes([0.0, 0.0, 0.2, 0.1], frame_on=False) # self.linkbutton = CheckButtons(self.rax, ['Link'], [link]) # self.linkclicked = self.linkbutton.on_clicked(self.toggle_button) return h def pretty_datetime(self, my_date, my_time): _date = _time = None if my_date is not None: try: _date = datetime.strptime(my_date, '%Y%m%d') except ValueError: pass if my_time is not None: try: _time = datetime.strptime(my_time, '%H%M%S.%f') except ValueError: try: _time = datetime.strptime(my_time, '%H%M%S') except ValueError: pass _date_fmt = '' if _date is not None or _time is not None: if _date is not None: _date_fmt = '{} '.format(_date.strftime("%Y-%m-%d")) else: logger.debug('Cannot add date for \"{}\"'.format(my_date)) if _time is not None: _date_fmt += _time.strftime("%H:%M:%S") else: logger.debug('Cannot add time for \"{}\"'.format(my_time)) return _date_fmt
[docs] def upper_left_text(self, im): """Update upper left text """ data = {} for attr in ['patientName', 'patientID']: try: data[attr] = getattr(im, attr, '') except ValueError: data[attr] = '' for attr in ['StudyDate', 'StudyTime']: try: data[attr] = im.getDicomAttribute(attr) except Exception: data[attr] = '' _date_fmt = self.pretty_datetime(data['StudyDate'], data['StudyTime']) fmt = '' if data['patientName']: pat_nam: str = '{}'.format(data['patientName']) while pat_nam[-1] == '^': pat_nam = pat_nam[:-1] pat_nam = pat_nam.replace('^', ', ') fmt = pat_nam if data['patientID']: fmt += '\n{}'.format(data['patientID']) if len(_date_fmt) > 0: fmt += '\n{}'.format(_date_fmt) return fmt
[docs] def upper_right_text(self, im): """Update upper right text """ data = {} for attr in ['seriesNumber', 'seriesDescription']: try: data[attr] = getattr(im, attr, '') except ValueError: data[attr] = '' for attr in ['SeriesDate', 'SeriesTime']: try: data[attr] = im.getDicomAttribute(attr) except Exception: data[attr] = '' _date_fmt = self.pretty_datetime(data['SeriesDate'], data['SeriesTime']) fmt = '' if data['seriesNumber']: fmt = '{}. '.format(data['seriesNumber']) fmt += data['seriesDescription'] if len(_date_fmt) > 0: fmt += '\n{}'.format(_date_fmt) return fmt
def connect_draw(self, roi=None, color='w', callback_quit=None): self.poly_color = color self.callback_quit = callback_quit idx = self.im[0]['idx'] if roi is None: self.poly = {} self.vertices = {} if self.follow: self.poly[0, idx] = MyPolygonSelector(self.ax[0, 0], self.onselect, lineprops={'color': self.poly_color}, tag=(0, idx), copy=self.on_copy, paste=self.on_paste) else: self.poly[idx] = MyPolygonSelector(self.ax[0, 0], self.onselect, lineprops={'color': self.poly_color}, tag=idx, copy=self.on_copy, paste=self.on_paste) else: self.poly = {} self.vertices = roi if self.follow: for tag in range(self.im[0]['tags']): for i in range(self.im[0]['slices']): vertices = copy.copy(self.vertices[tag, i]) \ if (tag, i) in self.vertices else None self.poly[tag, i] = MyPolygonSelector(self.ax[0, 0], self.onselect, lineprops={ 'color': self.poly_color}, vertices=vertices, tag=(tag, i), copy=self.on_copy, paste=self.on_paste) # Polygon on single slice and tag 0, only if i == idx and tag == 0: assert self.poly[tag, i].tag == (tag, i), \ "Tag index mismatch {}!={}".format((tag, i), self.poly[tag, i].tag) self.poly[tag, i].connect_default_events() self.poly[tag, i].set_visible(True) self.poly[tag, i].update() else: assert self.poly[tag, i].tag == (tag, i), \ "Tag index mismatch {}!={}".format((tag, i), self.poly[tag, i].tag) self.poly[tag, i].disconnect_events() self.poly[tag, i].set_visible(False) self.poly[tag, i].update() else: for i in range(self.im[0]['slices']): vertices = copy.copy(self.vertices[i]) if i in self.vertices else None self.poly[i] = MyPolygonSelector(self.ax[0, 0], self.onselect, lineprops={'color': self.poly_color}, vertices=vertices, tag=i, copy=self.on_copy, paste=self.on_paste) # Polygon on single slice only if i != idx: assert self.poly[i].tag == i, \ "Tag index mismatch {}!={}".format(i, self.poly[i].tag) self.poly[i].disconnect_events() self.poly[i].set_visible(False) self.poly[i].update() self.cidscroll = self.fig.canvas.mpl_connect('scroll_event', self.scroll) self.cidkeypress = self.fig.canvas.mpl_connect('key_press_event', self.key_press) def disconnect_draw(self): if self.follow: for t in range(self.im[0]['tags']): for idx in range(self.im[0]['slices']): if (t, idx) in self.poly and self.poly[t, idx] is not None: self.poly[t, idx].disconnect_events() else: for idx in range(self.im[0]['slices']): if idx in self.poly and self.poly[idx] is not None: self.poly[idx].disconnect_events() self.fig.canvas.mpl_disconnect(self.scroll) self.fig.canvas.mpl_disconnect(self.cidkeypress) def connect(self): # Connect to all the events we need # self.cidenter = self.fig.canvas.mpl_connect('axes_enter_event', self.enter_axes) # self.cidleave = self.fig.canvas.mpl_connect('axes_leave_event', self.leave_axes) self.cidscroll = self.fig.canvas.mpl_connect('scroll_event', self.scroll) self.cidkeypress = self.fig.canvas.mpl_connect('key_press_event', self.key_press) self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.on_press) self.cidrelease = self.fig.canvas.mpl_connect('button_release_event', self.on_release) self.cidmotion = self.fig.canvas.mpl_connect('motion_notify_event', self.on_motion) def disconnect(self): self.fig.canvas.mpl_disconnect(self.scroll) self.fig.canvas.mpl_disconnect(self.cidkeypress) self.fig.canvas.mpl_disconnect(self.cidpress) self.fig.canvas.mpl_disconnect(self.cidrelease) self.fig.canvas.mpl_disconnect(self.cidmotion) def onselect(self, vertices): idx = self.im[0]['idx'] tag = None if self.follow: tag = self.im[0]['tag'] self.vertices[tag, idx] = copy.copy(vertices) else: self.vertices[idx] = copy.copy(vertices) if self.callback_onselect is not None: self.callback_onselect(idx, tag, vertices) def on_copy(self, polygon): self.paste_buffer = polygon def on_paste(self): return self.paste_buffer # def grid_from_roi(self): # """Return drawn ROI as grid. # # Returns: # Numpy ndarray with shape (nz,ny,nx) from original image, dtype ubyte. # Voxels inside ROI is 1, 0 outside. # """ # nt, nz, ny, nx = self.im[0]['tags'], self.im[0]['slices'], self.im[0]['rows'], # self.im[0]['columns'] # if self.follow: # grid = np.zeros((nt, nz, ny, nx), dtype=np.ubyte) # for idx in range(nz): # last_used_tag = None # for t in range(nt): # tag = t, idx # if tag not in self.vertices or self.vertices[tag] is None: # if last_used_tag is None: # # Most probably a slice with no ROIs # continue # # Propagate last drawn ROI to unfilled tags # self.vertices[tag] = self.vertices[last_used_tag] # else: # last_used_tag = tag # path = MplPath(self.vertices[tag]) # x, y = np.meshgrid(np.arange(nx), np.arange(ny)) # x, y = x.flatten(), y.flatten() # points = np.vstack((x, y)).T # grid[t, idx] = path.contains_points(points).reshape((ny, nx)) # else: # grid = np.zeros((nz, ny, nx), dtype=np.ubyte) # for idx in range(nz): # if idx not in self.vertices or self.vertices[idx] is None: # continue # path = MplPath(self.vertices[idx]) # x, y = np.meshgrid(np.arange(nx), np.arange(ny)) # x, y = x.flatten(), y.flatten() # points = np.vstack((x, y)).T # grid[idx] = path.contains_points(points).reshape((ny, nx)) # return grid
[docs] def get_roi(self): """Return drawn ROI. Returns: Dict of slices, index as [tag,slice] or [slice], each is list of (x,y) pairs. """ vertices = {} for tag in self.poly.keys(): if len(self.poly[tag].verts) > 0: vertices[tag] = self.poly[tag].verts return vertices
# def enter_axes(self, event): # if event.inaxes == self.im['ax']: # print('enter_axes', self.im['ax'], event.inaxes) # elif self.im2 is not None and event.inaxes == self.im2['ax']: # print('enter_axes2', self.im2['ax'], event.inaxes) # def leave_axes(self, event): # print('leave_axes', event.inaxes) def key_press(self, event): if event.key == 'up': self.scroll_data(event.inaxes, 1) elif event.key == 'down': self.scroll_data(event.inaxes, -1) elif event.key == 'left': self.advance_data(event.inaxes, -1) elif event.key == 'right': self.advance_data(event.inaxes, 1) elif event.key == 'pageup': self.viewport_advance(-self.rows * self.columns) elif event.key == 'pagedown': self.viewport_advance(self.rows * self.columns) elif event.key == 'ctrl+home': self.viewport_set(0) elif event.key == 'ctrl+end': self.viewport_set(len(self.im) - self.rows * self.columns) elif event.key == 'ctrl+left': self.viewport_advance(-1) elif event.key == 'ctrl+right': self.viewport_advance(1) elif event.key == 'ctrl+up': self.viewport_advance(-self.columns) elif event.key == 'ctrl+down': self.viewport_advance(self.columns) elif event.key == 'H' or event.key == 'h': # Hide display self.toggle_hide(event.inaxes) elif event.key == 'W' or event.key == 'w': # Normalize window center/width using a probability histogram self.normalize_window(event.inaxes) elif event.key == 'Q' or event.key == 'q': # Quit Viewer # Set present window/level on Series objects for i in self.im.keys(): self.im[i]['im'].windowCenter = self.im[i]['level'] self.im[i]['im'].windowWidth = self.im[i]['window'] if self.callback_quit is not None: self.callback_quit() # else: # print('key_press: {}'.format(event.key)) def scroll(self, event): if event.button == 'up': self.scroll_data(event.inaxes, 1) elif event.button == 'down': self.scroll_data(event.inaxes, -1) def find_viewport_from_event(self, inaxes): for vp_idx in self.viewport: vp = self.viewport[vp_idx] if vp is not None: ax = vp['ax'] if inaxes == ax: return vp_idx # Do nothing when the event does not match with any viewport axes return None def find_image_from_event(self, inaxes): for vp_idx in self.viewport: vp = self.viewport[vp_idx] if vp is not None: ax = vp['ax'] if inaxes == ax: im_idx = vp['present'] return self.im[im_idx] # Do nothing when the event does not match with any viewport axes return None def scroll_data(self, inaxes, increment): im = self.find_image_from_event(inaxes) if im is None: return old_idx = im['idx'] im['idx'] = min(max(im['idx'] + increment, 0), im['slices'] - 1) if self.link: # Scroll all images to same index (if possible) self.scroll_all_data(im, im['idx']) im['modified'] = old_idx != im['idx'] if self.poly is not None: new_idx = im['idx'] if self.follow: old_idx = im['tag'], old_idx new_idx = im['tag'], new_idx if old_idx in self.poly: assert self.poly[old_idx].tag == old_idx, \ "Tag index mismatch {}!={}".format(old_idx, self.poly[old_idx].tag) if new_idx in self.poly: assert self.poly[new_idx].tag == new_idx, \ "Tag index mismatch {}!={}".format(new_idx, self.poly[new_idx].tag) if im['modified']: self.poly[old_idx].disconnect_events() self.poly[old_idx].set_visible(False) self.poly[old_idx].update() if new_idx in self.poly and self.poly[new_idx] is not None: self.poly[new_idx].connect_default_events() self.poly[new_idx].set_visible(True) self.poly[new_idx].update() else: self.poly[new_idx] = MyPolygonSelector(self.ax[0, 0], self.onselect, lineprops={'color': self.poly_color}, tag=new_idx, copy=self.on_copy, paste=self.on_paste) # if self.link and self.im['scrollable'] and self.im2['scrollable']: # self.im['idx'] = min(max(self.im['idx'] + increment, 0), self.im['slices']-1) # self.im2['idx'] = self.im['idx'] # elif inaxes == self.im['ax'] and self.im['scrollable']: # self.im['idx'] = min(max(self.im['idx'] + increment, 0), self.im['slices']-1) # elif self.im2 is not None and inaxes == self.im2['ax'] and self.im2['scrollable']: # self.im2['idx'] = min(max(self.im2['idx'] + increment, 0), self.im2['slices']-1) self.update() def scroll_all_data(self, im, idx): for vp_idx in self.viewport: vp = self.viewport[vp_idx] if vp is not None: im_idx = vp['present'] if im_idx != idx: im2 = self.im[im_idx] old_idx = im2['idx'] im2['idx'] = min(max(idx, 0), im2['slices'] - 1) im2['modified'] = old_idx != im2['idx']
[docs] def advance_data(self, inaxes, increment): """Advance display to next/previous tag value""" im = self.find_image_from_event(inaxes) if im is None or im['tag_axis'] is None: return old_tag = im['tag'] im['tag'] = min(max(im['tag'] + increment, 0), len(im['tag_axis']) - 1) im['modified'] = old_tag != im['tag'] if self.poly is not None and self.follow and im['modified']: new_tag = im['tag'] idx = im['idx'] assert self.poly[old_tag, idx].tag == (old_tag, idx), \ "Tag index mismatch {}!={}".format((old_tag, idx), self.poly[old_tag, idx].tag) if (new_tag, idx) not in self.poly and (old_tag, idx) in self.poly and \ self.poly[old_tag, idx] is not None: # Copy the polygon to next tag when there is none self.poly[new_tag, idx] = MyPolygonSelector(self.ax[0, 0], self.onselect, lineprops={'color': self.poly_color}, vertices=self.poly[old_tag, idx].verts, tag=(new_tag, idx), copy=self.on_copy, paste=self.on_paste) assert self.poly[old_tag, idx].tag == (old_tag, idx), \ "Tag index mismatch {}!={}".format((old_tag, idx), self.poly[old_tag, idx].tag) self.poly[old_tag, idx].disconnect_events() self.poly[old_tag, idx].set_visible(False) self.poly[old_tag, idx].update() assert self.poly[new_tag, idx].tag == (new_tag, idx), \ "Tag index mismatch {}!={}".format((new_tag, idx), self.poly[new_tag, idx].tag) self.poly[new_tag, idx].connect_default_events() self.poly[new_tag, idx].set_visible(True) self.poly[new_tag, idx].update() # if self.link and self.im['scrollable'] and self.im2['scrollable']: # self.im['idx'] = min(max(self.im['idx'] + increment, 0), self.im['slices']-1) # self.im2['idx'] = self.im['idx'] # elif inaxes == self.im['ax'] and self.im['scrollable']: # self.im['idx'] = min(max(self.im['idx'] + increment, 0), self.im['slices']-1) # elif self.im2 is not None and inaxes == self.im2['ax'] and self.im2['scrollable']: # self.im2['idx'] = min(max(self.im2['idx'] + increment, 0), self.im2['slices']-1) self.update()
[docs] def viewport_advance(self, increment): """Advance viewport by given increment """ self.viewport_set(self.viewport_idx + increment)
[docs] def viewport_set(self, position): """Set viewport to image position """ images = len(self.im) # Position must be in range 0:images-(rows*columns) vp_idx = min(position, images - self.rows * self.columns - 1) vp_idx = max(vp_idx, 0) if vp_idx == self.viewport_idx: # No change return # print('viewport_set: old idx {}, new idx {}'.format(self.viewport_idx, position)) self.viewport_idx = vp_idx new_viewport = {} for row in range(self.rows): for column in range(self.columns): if vp_idx in self.im: new_viewport[vp_idx] = { 'ax': self.ax[row, column], 'present': None, 'next': vp_idx, 'h': None } else: new_viewport[vp_idx] = None self.ax[row, column].set_axis_off() vp_idx += 1 self.viewport = new_viewport self.update()
[docs] def toggle_hide(self, inaxes): """Toggle the display of text on images """ # im = self.find_image_from_event(inaxes) # if im is None: # return for im in self.im: self.im[im]['show_text'] = not self.im[im]['show_text'] for artist in self.im[im]['artists']: artist.set_visible(self.im[im]['show_text']) self.im[im]['modified'] = True self.update()
def normalize_window(self, inaxes): im = self.find_image_from_event(inaxes) if im is None: return # Normalize on displayed slice only probs = (0.01, 0.99) if im['slice_axis'] is None: # 2D data vmin, vmax = im['im'].calculate_clip_range(probs) elif im['tag_axis'] is None: # 3D data idx = im['idx'] vmin, vmax = im['im'][idx].calculate_clip_range(probs) else: # 4D data idx = im['idx'] tag = im['tag'] vmin, vmax = im['im'][tag, idx].calculate_clip_range(probs) im['vmin'] = vmin im['vmax'] = vmax level = (np.float32(vmax) + np.float32(vmin)) / 2 if np.isnan(level): level = 1 # if abs(level) > 2: # level = round(level) im['level'] = level window = vmax - vmin if np.isnan(window): window = 1 # if abs(window) > 2: # window = round(window) im['window'] = window im['modified'] = True self.update() def update_tag(self, value): # value = int(round(self.im['slider'].val)) for key, im in self.im.items(): if im is not None and im['slider'] is not None: inc = 1 / len(im['im'].tags[0]) # Increment per tag step tag_index = int(round(int(value / inc) * inc)) # Tag index im['tag'] = tag_index im['slider'].valtext.set_text(pretty_tag_value(im)) im['modified'] = True self.update() def toggle_button(self, button): if button == 'Link': self.link = self.linkbutton.get_status()[0] # Link button is button 0 if self.link: # Display same slice for both images self.im2['idx'] = self.im['idx'] self.update() def on_press(self, event): # Button press - determine action if event.button == 1 and not event.dblclick: self.start_window_level(event) def on_release(self, event): # Button release - determine action if event.button == 1: self.end_window_level(event) def on_motion(self, event): # Motion - determine action if event.button == 1: self.modify_window_level(event) def start_window_level(self, event): # On button press we will see of the mouse is over us and store some data im = self.find_image_from_event(event.inaxes) if im is not None: im['press'] = event.xdata, event.ydata def end_window_level(self, event): # On button release im = self.find_image_from_event(event.inaxes) if im is not None: im['press'] = None def modify_window_level(self, event): # On motion, modify window and level, and update display im = self.find_image_from_event(event.inaxes) if im is not None and im['press'] is not None: # delta = (im['vmax'] - im['vmin']) / 100 delta = (im['im'].max() - im['im'].min()) / 100 dx = delta * (event.xdata - im['press'][0]) dy = delta * (im['press'][1] - event.ydata) im['press'] = event.xdata, event.ydata im['window'] = max(POSITIVE_EPS, im['window'] + dy) assert im['window'] >= 0, "Window must be non-negative." im['level'] = im['level'] + dx im['vmin'] = im['level'] - im['window'] / 2 im['vmax'] = im['level'] + im['window'] / 2 im['vmin'], im['vmax'] = _check_vmin_vmax(im['vmin'], im['vmax'], im['norm']) im['modified'] = True self.update()
class MyPolygonSelector(PolygonSelector): """Select a polygon region of an axes. Place vertices with each mouse click, and make the selection by completing the polygon (clicking on the first vertex). Hold the *ctrl* key and click and drag a vertex to reposition it (the *ctrl* key is not necessary if the polygon has already been completed). Hold the *shift* key and click and drag anywhere in the axes to move all vertices. Press the *esc* key to start a new polygon. For the selector to remain responsive you must keep a reference to it. Class MyPolygonSelector subclasses matplotlib.widgets.PolygonSelector. Allows to set an initial polygon. """ def __init__(self, ax, onselect, useblit=False, lineprops=None, markerprops=None, vertex_select_radius=10, vertices=None, tag=None, copy=None, paste=None): super().__init__(ax, onselect, useblit=useblit, props=lineprops, handle_props=markerprops, grab_range=vertex_select_radius) self.tag = tag self.copy_handle = copy self.paste_handle = paste self._polygon_completed = False if lineprops is None: self.lineprops = dict(color='k', linestyle='-', linewidth=2, alpha=0.5) else: self.lineprops = lineprops self.lineprops['animated'] = self.useblit if markerprops is None: self.markerprops = dict(markeredgecolor='k', markerfacecolor=self.lineprops.get('color', 'k')) else: self.markerprops = markerprops self.vertex_select_radius = vertex_select_radius if vertices is not None and len(vertices): self.verts = vertices def _on_key_release(self, event): """Key release event handler.""" # Add back the pending vertex if leaving the 'move_vertex' or # 'move_all' mode (by checking the released key) if (not self._selection_completed and (event.key == self._state_modifier_keys.get('move_vertex') or event.key == self._state_modifier_keys.get('move_all'))): self._xys.append((event.xdata, event.ydata)) self._draw_polygon() # Reset the polygon if the released key is the 'clear' key. elif event.key == self._state_modifier_keys.get('clear'): event = self._clean_event(event) self._xys = [(event.xdata, event.ydata)] self._selection_completed = False self._remove_box() self.set_visible(True) # Copy polygon to paste buffer using handle elif event.key.upper() == 'C': if self.copy_handle is not None: self.copy_handle(self) # Add polygon from paste buffer handle elif event.key.upper() == 'V': if self.paste_handle is not None: obj = self.paste_handle() self.verts = obj.verts self._selection_completed = obj._selection_completed self._draw_polygon() self.set_visible(True) if self._selection_completed: self.onselect(self.verts) def default_layout(fig, n): """Setup a default layout for given number of axes. Args: fig: matplotlib figure n: Number of axes required Returns: List of Axes Raises: ValueError: When no ax axes or > 9*9 are required. When no figure is given. """ if fig is None: raise ValueError("No Figure given") if n < 1: raise ValueError("No layout when no axes are required") for rows in range(1, 5): if rows * rows >= n: return fig.subplots(rows, rows, squeeze=False) # columns = rows if rows * (rows + 1) >= n: return fig.subplots(rows, rows + 1, squeeze=False) # columns = rows+1 raise ValueError("Too many axes required (n={})".format(n)) def grid_from_roi(im: Series, vertices: dict, single: bool = False) -> Union[bool, Series]: """Return drawn ROI as grid. Args: im (imagedata.Series): Series object as template vertices: The polygon vertices, as a dictionary of tags of (x,y) single (bool): Draw ROI in single slice per tag Returns: Series with shape (nz,ny,nx) from original image, dtype ubyte. Voxels inside ROI is 1, 0 outside. """ def _roi_in_any_slice(tag): """Check whether there is a ROI in any slice""" t, i = tag for idx in range(im.slices): # if (t, idx) in vertices and vertices[t, idx] is None: # print('Check {} None'.format((t, idx))) # elif (t, idx) in vertices: # print('Check {} {}'.format((t, idx), len(vertices[t, idx]))) # else: # print('Check {} not found'.format((t, idx))) if (t, idx) in vertices and vertices[t, idx] is not None: return True return False keys = list(vertices.keys())[0] # print('Viewer.grid_from_roi: keys: {}'.format(keys)) # print('Viewer.grid_from_roi: vertices: {}'.format(vertices)) follow = issubclass(type(keys), tuple) nt, nz, ny, nx = len(im.tags[0]), im.slices, im.rows, im.columns input_order = im.input_order if follow and not single: grid = np.zeros_like(im, dtype=np.ubyte) skipped = [] copied = [] for idx in range(nz): last_used_tag = None for t in range(nt): tag = t, idx if tag not in vertices or vertices[tag] is None: if last_used_tag is None: # Most probably a slice with no ROIs skipped.append(tag) continue elif _roi_in_any_slice(tag): # print('Found in some slice for', tag) skipped.append(tag) continue # Propagate last drawn ROI to unfilled tags vertices[tag] = copy.copy(vertices[last_used_tag]) copied.append((tag, last_used_tag)) else: last_used_tag = tag path = MplPath(vertices[tag]) x, y = np.meshgrid(np.arange(nx), np.arange(ny)) x, y = x.flatten(), y.flatten() points = np.vstack((x, y)).T grid[t, idx] = path.contains_points(points).reshape((ny, nx)) elif follow and single: grid = np.zeros_like(im, dtype=np.ubyte) skipped = [] copied = [] last_used_tag = None for t in range(nt): for idx in range(nz): tag = t, idx if tag not in vertices or vertices[tag] is None: if last_used_tag is None: # Most probably a slice with no ROIs skipped.append(tag) continue elif last_used_tag[1] != idx: continue elif _roi_in_any_slice(tag): # print('Found in some slice for', tag) skipped.append(tag) continue # Propagate last drawn ROI to unfilled tags vertices[tag] = copy.copy(vertices[last_used_tag]) copied.append((tag, last_used_tag)) else: last_used_tag = tag path = MplPath(vertices[tag]) x, y = np.meshgrid(np.arange(nx), np.arange(ny)) x, y = x.flatten(), y.flatten() points = np.vstack((x, y)).T grid[t, idx] = path.contains_points(points).reshape((ny, nx)) # if len(skipped) > 0: # print('Skipped: {}'.format(skipped)) # if len(copied) > 0: # print('Copied: {}'.format(copied)) else: grid = np.zeros((nz, ny, nx), dtype=np.ubyte) for idx in range(nz): if idx not in vertices or vertices[idx] is None: continue path = MplPath(vertices[idx]) x, y = np.meshgrid(np.arange(nx), np.arange(ny)) x, y = x.flatten(), y.flatten() points = np.vstack((x, y)).T grid[idx] = path.contains_points(points).reshape((ny, nx)) input_order = 'none' if im.ndim == 2: grid = grid.reshape((ny, nx)) return Series(grid, input_order=input_order, template=im, geometry=im) def get_level(si, level): if level is None: # First, attempt to get DICOM attribute try: level = si.windowCenter except (KeyError, AttributeError, TypeError): pass try: if len(level) > 1: level = level[0] except TypeError: pass if level is None: level = (np.float32(np.nanmax(si)) + np.float32(np.nanmin(si))) / 2 if np.isnan(level): level = 1 if abs(level) > 2: level = round(level) return level def get_window_level(si, norm, window, level): if window is None: # First, attempt to get DICOM attribute try: window = si.windowWidth except (KeyError, AttributeError, TypeError): pass try: if len(window) > 1: window = window[0] # print('Viewer.get_window_level: {} len {}'.format(window, len(window))) except TypeError: pass if window is None: window = np.float32(np.nanmax(si)) - np.float32(np.nanmin(si)) if np.isnan(window): window = 1 if abs(window) > 2: window = round(window) level = get_level(si, level) vmin, vmax = _check_vmin_vmax(level - window / 2, level + window / 2, norm) return window, level, vmin, vmax def _check_vmin_vmax(vmin, vmax, norm): if type(norm) is type: norm = norm(vmin=vmin, vmax=vmax) if type(norm) is matplotlib.colors.LogNorm: vmin = max(POSITIVE_EPS, vmin) vmax = max(POSITIVE_EPS, vmin, vmax) return vmin, vmax def build_info(im, colormap, norm, colorbar, window, level): if im is None: return None if not issubclass(type(im), Series): raise ValueError('Cannot display image of type {}'.format(type(im))) # im might be modified below (color version), hence, save present color presentation im_color = im.color im_colormap = copy.copy(im.colormap) im_colormap_norm = copy.copy(im.colormap_norm) try: im_colormap_label = im.colormap_label except ValueError: im_colormap_label = None if colormap is None: colormap = 'Greys_r' try: if colormap == 'Greys_r' and im.photometricInterpretation == 'MONOCHROME1': colormap = 'Greys' except ValueError: pass if np.issubdtype(im.dtype, np.floating): lut = 256 elif np.issubdtype(im.dtype, np.complexfloating): lut = 256 logger.warning('Displaying real part of complex values.') im = np.real(im) elif im.color: lut = 256 im = im.view(dtype=np.uint8).reshape(im.shape + (3,)) else: lut = (np.nanmax(im).item()) + 1 if im_colormap is None: if not issubclass(type(colormap), matplotlib.colors.Colormap): colormap = plt.get_cmap(colormap, lut) colormap.set_bad(color='k') # Important for log display of non-positive values colormap.set_under(color='k') colormap.set_over(color='w') else: colormap = im_colormap window, level, vmin, vmax = get_window_level(im, norm, window, level) if im_colormap_norm is None: if type(norm) is type: norm = norm(vmin=vmin, vmax=vmax) if im_color: norm = None else: norm = im_colormap_norm tag_axis = im.get_tag_axis() slice_axis = im.get_slice_axis() return { 'im': im, # Image Series instance 'input_order': im.input_order, 'color': im_color, 'modified': True, # update() 'show_text': True, # Show text on display 'artists': [], # List of artists 'slider': None, # 4D slider 'lower_left_text': None, # AnchoredText object 'lower_left_data': None, # Tuple of present data 'lower_right_text': None, # AnchoredText object 'lower_right_data': None, # Tuple of present data 'scrollable': im.slices > 1, # Can we scroll the instance? 'taggable': tag_axis is not None, # Can we slide through tags? 'tags': len(im.tags[0]), # Number of tags 'slices': im.slices, # Number of slices 'rows': im.rows, # Number of rows 'columns': im.columns, # Number of columns 'tag': 0, # Displayed tag index 'idx': im.slices // 2, # Displayed slice index 'tag_axis': tag_axis, # Axis instance of im 'slice_axis': slice_axis, # Axis instance of im 'colormap': colormap, # Colour map 'colormap_label': im_colormap_label, # Colour map label 'norm': im_colormap_norm if im_color else norm, # Normalization function 'colorbar': im_colormap, # Display colorbar unless RGB image 'window': window, # Window center 'level': level, # Window level 'vmin': vmin, # Lower window value 'vmax': vmax # Upper window value } def pretty_tag_value(im): tag = im['tag'] if im['input_order'] == 'time': return '{0:0.2f}s'.format(im['im'].timeline[tag]) elif im['input_order'] == 'b': return '{}'.format(int(im['im'].tags[0][tag])) elif im['input_order'] == 'te': return '{}ms'.format(int(im['im'].tags[0][tag])) elif im['input_order'] == 'fa': return '{}'.format(im['im'].tags[0][tag]) else: return '{}'.format(im['im'].tags[0][tag]) def pretty_window_level(im): si, window, level = im['im'], im['window'], im['level'] if si.dtype.kind in ('i', 'u'): fmt = 'SL: {0:d}\nW: {1:d} C: {2:d}' window = int(window) level = int(level) else: fmt = 'SL: {0:d}\nW: {1:.2f} C: {2:.2f}' window = np.around(window, 2) level = np.around(level, 2) return fmt, window, level