class OrthoSlicer3D(object):
def __init__(self, data, affine=None, axes=None, title=None):
    """
    Parameters
    ----------
    data : array-like
        The data that will be displayed by the slicer. Should have 3+
        dimensions.
    affine : array-like or None, optional
        Affine transform for the data. This is used to determine
        how the data should be sliced for plotting into the sagittal,
        coronal, and axial view axes. If None, identity is assumed.
        The aspect ratio of the data is inferred from the affine
        transform.
    axes : sequence of mpl.Axes or None, optional
        3 or 4 axes instances for the 3 slices plus volumes,
        or None (default), in which case a new 2x2 figure is created.
        Axes beyond the fourth are ignored.
    title : str or None, optional
        The title to display. Can be None (default) to display no
        title. When the axes are created here (``axes=None``) it is
        also used as the figure window title.

    Raises
    ------
    ValueError
        If `data` has fewer than 3 dimensions, `affine` is not 4x4, or
        fewer than 3 axes are supplied.
    TypeError
        If `data` is complex.
    """
    # Use these late imports of matplotlib so that we have some hope that
    # the test functions are the first to set the matplotlib backend. The
    # tests set the backend to something that doesn't require a display.
    self._plt = plt = optional_package('matplotlib.pyplot')[0]
    mpl_patch = optional_package('matplotlib.patches')[0]
    self._title = title
    self._closed = False

    # Validate inputs before building any plotting state.
    data = np.asanyarray(data)
    if data.ndim < 3:
        raise ValueError('data must have at least 3 dimensions')
    if np.iscomplexobj(data):
        raise TypeError("Complex data not supported")
    affine = np.array(affine, float) if affine is not None else np.eye(4)
    if affine.shape != (4, 4):
        raise ValueError('affine must be a 4x4 matrix')

    # Determine our orientation: _order is the argsort of the axis part
    # of the orientation codes (used below to pick data dimensions for
    # each view), _flips marks axes whose orientation code is negative.
    self._affine = affine
    codes = axcodes2ornt(aff2axcodes(self._affine))
    self._order = np.argsort([c[0] for c in codes])
    self._flips = np.array([c[1] < 0 for c in codes])[self._order]
    self._flips = list(self._flips) + [False]  # add volume dim
    self._scalers = voxel_sizes(self._affine)
    self._inv_affine = np.linalg.inv(affine)

    # current volume info
    self._volume_dims = data.shape[3:]
    self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data
    self._data = data
    # Display range clipped to the 1st-99th percentiles of the data.
    self._clim = np.percentile(data, (1., 99.))
    del data

    if axes is None:  # make the axes
        # ^ +---------+   ^ +---------+
        # | |         |   | |         |
        #   |   Sag   |     |   Cor   |
        # S |    0    |   S |    1    |
        #   |         |     |         |
        #   |         |     |         |
        #   +---------+     +---------+
        #        A  -->     <--  R
        # ^ +---------+     +---------+
        # | |         |     |         |
        #   |  Axial  |     |   Vol   |
        # A |    2    |     |    3    |
        #   |         |     |         |
        #   |         |     |         |
        #   +---------+     +---------+
        #   <--  R          <--  t  -->
        fig, axes = plt.subplots(2, 2)
        fig.set_size_inches((8, 8), forward=True)
        self._axes = [axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]]
        plt.tight_layout(pad=0.1)
        if self.n_volumes <= 1:
            # No volume dimension to browse: drop the volume panel.
            fig.delaxes(self._axes[3])
            self._axes.pop(-1)
        if self._title is not None:
            # FigureCanvasBase.set_window_title was deprecated in
            # matplotlib 3.4 and removed in 3.6; the figure manager owns
            # the window title. Fall back to the canvas method only when
            # there is no manager (old matplotlib).
            manager = getattr(fig.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(str(title))
            else:
                fig.canvas.set_window_title(str(title))
    else:
        axes = list(axes)
        if len(axes) < 3:
            raise ValueError('axes must contain at least 3 Axes instances '
                             '(got %d)' % len(axes))
        # Three slice views, plus an optional fourth volume panel.
        self._axes = axes[:4]

    # Image handles and slice indices for the three orthogonal views. The
    # indices start at 0; the initial position is applied by the
    # _set_position call at the end of this method.
    self._ims, self._data_idx = list(), list()
    # set up axis crosshairs
    self._crosshairs = [None] * 3
    # Per-view aspect ratio: voxel size along the view's vertical axis
    # divided by voxel size along its horizontal axis.
    r = [self._scalers[self._order[2]] / self._scalers[self._order[1]],
         self._scalers[self._order[2]] / self._scalers[self._order[0]],
         self._scalers[self._order[1]] / self._scalers[self._order[0]]]
    self._sizes = [self._data.shape[order] for order in self._order]
    for ii, xax, yax, ratio, label in zip([0, 1, 2], [1, 0, 0], [2, 2, 1],
                                          r, ('SAIP', 'SRIL', 'ARPL')):
        ax = self._axes[ii]
        d = np.zeros((self._sizes[yax], self._sizes[xax]))
        im = ax.imshow(
            d, vmin=self._clim[0], vmax=self._clim[1], aspect=1,
            cmap='gray', interpolation='nearest', origin='lower')
        self._ims.append(im)
        vert = ax.plot([0] * 2, [-0.5, self._sizes[yax] - 0.5],
                       color=(0, 1, 0), linestyle='-')[0]
        horiz = ax.plot([-0.5, self._sizes[xax] - 0.5], [0] * 2,
                        color=(0, 1, 0), linestyle='-')[0]
        self._crosshairs[ii] = dict(vert=vert, horiz=horiz)
        # add text labels (top, right, bottom, left)
        view_lims = [0, self._sizes[xax], 0, self._sizes[yax]]
        bump = 0.01
        poss = [[view_lims[1] / 2., view_lims[3]],
                [(1 + bump) * view_lims[1], view_lims[3] / 2.],
                [view_lims[1] / 2., 0],
                [view_lims[0] - bump * view_lims[1], view_lims[3] / 2.]]
        anchors = [['center', 'bottom'], ['left', 'center'],
                   ['center', 'top'], ['right', 'center']]
        for pos, anchor, lab in zip(poss, anchors, label):
            ax.text(pos[0], pos[1], lab,
                    horizontalalignment=anchor[0],
                    verticalalignment=anchor[1])
        ax.axis(view_lims)
        ax.set_aspect(ratio)
        ax.patch.set_visible(False)
        ax.set_frame_on(False)
        ax.axes.get_yaxis().set_visible(False)
        ax.axes.get_xaxis().set_visible(False)
        self._data_idx.append(0)
    self._data_idx.append(-1)  # volume

    # Set up volumes axis
    if self.n_volumes > 1 and len(self._axes) > 3:
        ax = self._axes[3]
        try:
            ax.set_facecolor('k')
        except AttributeError:  # old mpl
            ax.set_axis_bgcolor('k')
        ax.set_title('Volumes')
        y = np.zeros(self.n_volumes + 1)
        x = np.arange(self.n_volumes + 1) - 0.5
        step = ax.step(x, y, where='post', color='y')[0]
        ax.set_xticks(np.unique(np.linspace(0, self.n_volumes - 1,
                                            5).astype(int)))
        ax.set_xlim(x[0], x[-1])
        # Y limits cover the data range plus a 1% margin on each side.
        # Convert to float first: max - min can overflow small integer
        # dtypes (e.g. int8). Constant data gets a fixed margin so the
        # limits never collapse onto a single value.
        data_min = float(self._data.min())
        data_max = float(self._data.max())
        span = data_max - data_min
        margin = 0.01 * span if span > 0 else 1.
        yl = [data_min - margin, data_max + margin]
        patch = mpl_patch.Rectangle([-0.5, yl[0]], 1., yl[1] - yl[0],
                                    fill=True, facecolor=(0, 1, 0),
                                    edgecolor=(0, 1, 0), alpha=0.25)
        ax.add_patch(patch)
        ax.set_ylim(yl)
        self._volume_ax_objs = dict(step=step, patch=patch)

    # Route interaction events from every figure hosting our axes.
    self._figs = {a.figure for a in self._axes}
    for fig in self._figs:
        fig.canvas.mpl_connect('scroll_event', self._on_scroll)
        fig.canvas.mpl_connect('motion_notify_event', self._on_mouse)
        fig.canvas.mpl_connect('button_press_event', self._on_mouse)
        fig.canvas.mpl_connect('key_press_event', self._on_keypress)
        fig.canvas.mpl_connect('close_event', self._cleanup)

    # actually set data meaningfully
    self._position = np.zeros(4)
    self._position[3] = 1.  # convenience for affine multiplication
    self._changing = False  # keep track of status to avoid loops
    self._links = []  # other viewers this one is linked to
    self._plt.draw()
    for fig in self._figs:
        fig.canvas.draw()
    self._set_volume_index(0, update_slices=False)
    self._set_position(0., 0., 0.)
    self._draw()