Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2# Copyright 2011-2019 Kwant authors.
3#
4# This file is part of Kwant. It is subject to the license terms in the file
5# LICENSE.rst found in the top-level directory of this distribution and at
6# https://kwant-project.org/license. A list of Kwant authors can be found in
7# the file AUTHORS.rst at the top-level directory of this distribution and at
8# https://kwant-project.org/authors.
# This module is imported by plotter.py. It contains all the expensive imports
# that we want to remove from plotter.py.
# All matplotlib imports must be isolated in a try, because even without
# matplotlib, iterators remain useful. Further, mpl_toolkits used for 3D
# plotting are also imported separately, to ensure that 2D plotting works even
# if 3D does not.
18import warnings
19from math import sqrt, pi
20import numpy as np
21from enum import Enum
# Referencing the bare name ``__IPYTHON__`` raises NameError unless it is
# defined; its presence is taken as the signal that we run under IPython.
try:
    __IPYTHON__
except NameError:
    is_ipython_kernel = False
else:
    is_ipython_kernel = True
# Availability flags for the optional plotting engines. They start out False
# and are switched to True by the import sections when the corresponding
# package (and kwant's colormaps) load. A ``global`` statement at module
# level is a no-op, so none is used.
mpl_available = False
plotly_available = False
try:
    import matplotlib
    import matplotlib.colors
    import matplotlib.cm
    # ``matplotlib.patches`` and ``matplotlib.transforms`` are used by the
    # custom collections in this module. Import them explicitly instead of
    # relying on them being loaded as a side effect of the imports here.
    import matplotlib.patches
    import matplotlib.transforms
    from matplotlib.figure import Figure
    from matplotlib import collections
    from . import _colormaps
    from matplotlib.colors import ListedColormap
    mpl_available = True
    kwant_red_matplotlib = ListedColormap(_colormaps.kwant_red,
                                          name="kwant red")
    # 3D support is imported separately so that a missing or broken
    # mpl_toolkits only disables 3D plotting, not 2D plotting.
    try:
        from mpl_toolkits import mplot3d
        has3d = True
    except ImportError:
        warnings.warn("3D plotting not available.", RuntimeWarning)
        has3d = False
except ImportError:
    warnings.warn("matplotlib is not available, if other engines are "
                  "unavailable, only iterator-providing functions will work",
                  RuntimeWarning)
try:
    import plotly.offline as plotly_module
    import plotly.graph_objs as plotly_graph_objs
    init_notebook_mode_set = False
    from . import _colormaps
except ImportError:
    warnings.warn("plotly is not available, if other engines are unavailable,"
                  " only iterator-providing functions will work",
                  RuntimeWarning)
else:
    plotly_available = True

    # Turn kwant's colormap into a plotly colorscale: a list of
    # (level, "rgb(r,g,b)") pairs with levels evenly spaced over [0, 1].
    # NOTE(review): assumes ``_colormaps.kwant_red`` is a numpy array of RGB
    # rows in [0, 1], so that ``255 *`` scales it -- confirm.
    _cmap_plotly = 255 * _colormaps.kwant_red
    _cmap_levels = np.linspace(0, 1, len(_cmap_plotly))
    kwant_red_plotly = [
        (level, 'rgb({},{},{})'.format(*rgb))
        for level, rgb in zip(_cmap_levels, _cmap_plotly)
    ]
# Names of all plotting engines whose packages imported successfully.
engines = frozenset(
    name
    for name, available in (("plotly", plotly_available),
                            ("matplotlib", mpl_available))
    if available
)
# Default engine: matplotlib when present, otherwise plotly, otherwise None.
engine = ("matplotlib" if mpl_available
          else "plotly" if plotly_available
          else None)
# Collections that allow for symbols and linewidths to be given in data space
# (not for general use, only implement what's needed for plotter)
def isarray(var):
    """Return True if *var* should be treated as a sequence of values.

    Anything indexable (anything with ``__getitem__``) counts, except
    strings, which are treated as single values.
    """
    return hasattr(var, '__getitem__') and not isinstance(var, str)
def nparray_if_array(var):
    """Return *var* as a numpy array if `isarray` considers it array-like.

    Values that `isarray` rejects (e.g. strings and non-indexable scalars)
    are returned unchanged.
    """
    if not isarray(var):
        return var
    return np.asarray(var)
if plotly_available:
    # Lookup tables that convert common matplotlib marker symbols to plotly
    # marker symbols: ``converter_map`` yields plotly's numeric symbol
    # codes, ``converter_map_3d`` yields symbol names and covers a smaller
    # set of markers. Entry order matters: the key list is shown to the user
    # as the list of supported symbols when a conversion fails.
    converter_map = {
        "o": 0,            # circle
        "v": 6,            # triangle-down
        "^": 5,            # triangle-up
        "<": 7,            # triangle-left
        ">": 8,            # triangle-right
        "s": 1,            # square
        "+": 3,            # cross
        "x": 4,            # x
        "*": 17,           # star
        "d": 2,            # diamond
        "h": 14,           # hexagon
        "no symbol": -1,   # sentinel for "no marker"; presumably handled downstream
    }

    converter_map_3d = {
        "o": "circle",
        "s": "square",
        "+": "cross",
        "x": "x",
        "d": "diamond",
    }
def error_string(symbol_input, supported):
    """Build the error message for unsupported marker symbol(s).

    *symbol_input* is the offending input and *supported* the collection of
    accepted symbols; both are formatted into the message as-is.
    """
    return (f"Input symbol/s '{symbol_input}' not supported. "
            f"Only the following characters are supported: {supported}")
def _convert_symbol(mpl_symbol, symbol_map):
    """Translate matplotlib marker symbol(s) through *symbol_map*.

    A single symbol maps to a single value; an array-like of symbols (as
    decided by `isarray`) maps to a list of values in the same order.

    Raises
    ------
    RuntimeError
        If any symbol is missing from *symbol_map*. The message lists the
        supported symbols.
    """
    # Index with [] rather than dict.get(): .get() never raises KeyError, so
    # unsupported symbols would silently turn into None.
    try:
        if isarray(mpl_symbol):
            return [symbol_map[symbol] for symbol in mpl_symbol]
        return symbol_map[mpl_symbol]
    except KeyError:
        raise RuntimeError(
            error_string(mpl_symbol, list(symbol_map))) from None


def convert_symbol_mpl_plotly(mpl_symbol):
    """Convert matplotlib marker symbol(s) to plotly numeric symbol codes.

    Uses ``converter_map``. Raises RuntimeError for unsupported symbols.
    """
    return _convert_symbol(mpl_symbol, converter_map)


def convert_symbol_mpl_plotly_3d(mpl_symbol):
    """Convert matplotlib marker symbol(s) to plotly 3D symbol names.

    Uses ``converter_map_3d``. Raises RuntimeError for unsupported symbols.
    """
    return _convert_symbol(mpl_symbol, converter_map_3d)
def convert_site_size_mpl_plotly(mpl_site_size, plotly_ref_px):
    """Convert a matplotlib marker size into a plotly marker size in pixels.

    matplotlib's scatter marker size is an area in square points
    (https://matplotlib.org/devdocs/api/_as_gen/matplotlib.pyplot.scatter.html),
    so its square root is a length in points. With 72 points and (assumed)
    96 pixels per inch, 1 point = 96/72 pixels. The result is then scaled
    by *plotly_ref_px*.

    Works element-wise on numpy arrays as well as on scalars.
    """
    return np.sqrt(mpl_site_size) * (96.0 / 72.0) * plotly_ref_px
def convert_colormap_mpl_plotly(r, g, b, a):
    """Format an RGBA color as a plotly ``rgba(...)`` string.

    The red, green and blue channels are multiplied by 255 (so presumably
    given in [0, 1] as in matplotlib); alpha is passed through unchanged.
    """
    red, green, blue = (255 * channel for channel in (r, g, b))
    return "rgba({},{},{},{})".format(red, green, blue, a)
def convert_cmap_list_mpl_plotly(mpl_cmap_name):
    """Return a plotly colorscale for a matplotlib colormap.

    Parameters
    ----------
    mpl_cmap_name : str or list
        Name of a registered matplotlib colormap, or an already converted
        plotly colorscale (a list), which is returned unchanged.

    Returns
    -------
    list
        For a name: ``cmap.N`` pairs ``(level, "rgba(...)")`` with levels
        evenly spaced over [0, 1].

    Raises
    ------
    ValueError
        If *mpl_cmap_name* is not a known colormap name.
    TypeError
        If *mpl_cmap_name* is neither a string nor a list.
    """
    if isinstance(mpl_cmap_name, str):
        # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; use the
        # colormap registry when it exists (matplotlib >= 3.5).
        registry = getattr(matplotlib, "colormaps", None)
        if registry is None:
            cmap = matplotlib.cm.get_cmap(mpl_cmap_name)
        else:
            try:
                cmap = registry[mpl_cmap_name]
            except KeyError as err:
                # The registry raises KeyError; keep reporting unknown names
                # with ValueError, as the previous get_cmap-based code did.
                raise ValueError(
                    f"{mpl_cmap_name!r} is not a known colormap name"
                ) from err
        return [
            (level, convert_colormap_mpl_plotly(*cmap(level)))
            for level in np.linspace(0, 1, cmap.N)
        ]

    if not isinstance(mpl_cmap_name, list):
        raise TypeError(
            "expected a colormap name (str) or a plotly colorscale (list), "
            f"got {type(mpl_cmap_name).__name__}")
    # Already a plotly colorscale: no conversion needed.
    return mpl_cmap_name
def convert_lead_cmap_mpl_plotly(mpl_lead_cmap_init, mpl_lead_cmap_end,
                                 N=255):
    """Build a linear plotly colorscale between two RGBA colors.

    Each channel is interpolated linearly from *mpl_lead_cmap_init* to
    *mpl_lead_cmap_end* over *N* samples. Red, green and blue are scaled by
    255; alpha is not. Returns *N* pairs ``(level, "rgba(r,g,b,a)")`` with
    levels evenly spaced over [0, 1].
    """
    start, stop = mpl_lead_cmap_init, mpl_lead_cmap_end
    reds, greens, blues = (np.linspace(start[i], stop[i], N) * 255
                           for i in range(3))
    alphas = np.linspace(start[3], stop[3], N)
    levels = np.linspace(0, 1, N)
    return [
        (lvl, 'rgba({},{},{},{})'.format(r, g, b, a))
        for lvl, r, g, b, a in zip(levels, reds, greens, blues, alphas)
    ]
if mpl_available:
    class LineCollection(collections.LineCollection):
        """Line collection whose line widths can be given in data units.

        Parameters
        ----------
        segments
            Passed on to matplotlib's ``LineCollection``.
        reflen : float, optional
            Reference length. When given, the widths recorded by
            `set_linewidths` are multiplied at draw time by the x-scale of
            the data transform (converted to points) times *reflen*. When
            ``None`` the recorded widths are used unchanged.
        **kwargs
            Passed on to matplotlib's ``LineCollection``.
        """

        def __init__(self, segments, reflen=None, **kwargs):
            super().__init__(segments, **kwargs)
            self.reflen = reflen

        def set_linewidths(self, linewidths):
            # Only record the requested widths; the widths handed to
            # matplotlib are derived from them in draw().
            self.linewidths_orig = nparray_if_array(linewidths)

        def draw(self, renderer):
            # NOTE(review): assumes set_linewidths() was called beforehand so
            # that ``linewidths_orig`` exists -- confirm for all call sites.
            if self.reflen is None:
                factor = 1
            else:
                # Only correct for an aspect ratio of 1, since only the
                # x-scale of the data transform is used. 72.0 is the number
                # of points per inch.
                x_scale = self.axes.transData.frozen().to_values()[0]
                factor = x_scale * 72.0 * self.reflen / self.figure.dpi

            super().set_linewidths(self.linewidths_orig * factor)
            return super().draw(renderer)
if mpl_available:
    class PathCollection(collections.PathCollection):
        """Path collection whose marker sizes and line widths can be in data units.

        Parameters
        ----------
        paths
            Passed on to matplotlib's ``PathCollection``.
        sizes : sequence of float, optional
            One scale factor per marker. Each path is scaled by its entry
            directly through an affine transform.
        reflen : float, optional
            Reference length. When given, paths are drawn with the (offset
            free) data transform scaled by *reflen*, and line widths are
            rescaled at draw time. Otherwise paths are scaled from points
            to display pixels.
        **kwargs
            Passed on to matplotlib's ``PathCollection``.
        """

        def __init__(self, paths, sizes=None, reflen=None, **kwargs):
            super().__init__(paths, sizes=sizes, **kwargs)

            self.reflen = reflen
            self.linewidths_orig = nparray_if_array(self.get_linewidths())

            Affine2D = matplotlib.transforms.Affine2D
            if sizes is None:
                # The default used to crash (iterating over None). Use an
                # empty (0, 3, 3) transform stack, the shape that matplotlib's
                # own collections use when no sizes are set.
                self.transforms = np.empty((0, 3, 3))
            else:
                # One 3x3 affine matrix per marker, scaling its path.
                self.transforms = np.array(
                    [Affine2D().scale(x).get_matrix() for x in sizes])

        def get_transforms(self):
            return self.transforms

        def get_transform(self):
            Affine2D = matplotlib.transforms.Affine2D
            if self.reflen is not None:
                # For the paths, use the data transformation but strip the
                # offset (added later together with the offsets).
                args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
                return Affine2D().from_values(*args).scale(self.reflen)
            else:
                # Sizes are in points; convert to display pixels.
                return Affine2D().scale(self.figure.dpi / 72.0)

        def draw(self, renderer):
            if self.reflen:
                # Note: only works for aspect ratio 1!
                factor = (self.axes.transData.frozen().to_values()[0] /
                          self.figure.dpi * 72.0 * self.reflen)
                self.set_linewidths(self.linewidths_orig * factor)

            # Calls the generic Collection.draw directly (as the original
            # implementation did).
            return collections.Collection.draw(self, renderer)
if mpl_available and has3d:
    # Whether Path3DCollection sorts its markers by depth. Sorting is optional.
    sort3d = True

    # To project a 3D length into 2D data coordinates we project two 3D half
    # circles. (This gives the same length as projecting the full unit
    # sphere.)
    phi = np.linspace(0, pi, 21)
    xyz = np.c_[np.cos(phi), np.sin(phi), 0 * phi].T.reshape(-1, 1, 21)

    # Shape (3, 42): the first 21 columns are a half circle in the xy-plane,
    # the next 21 a half circle in the yz-plane.
    unit_sphere = np.block([
        [xyz[0], xyz[2]],
        [xyz[1], xyz[0]],
        [xyz[2], xyz[1]],
    ])

    def projected_length(ax, length):
        """Return the projected 2D size of a 3D *length* on the 3D axes *ax*.

        The half circles of radius *length*, centred at the middle of the
        axes' 3D limits, are projected with the axes' current projection.
        The result is the largest distance of a projected point from the
        projected centre.
        """
        rc = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
        rc = rc.sum(axis=1) / 2.

        rs = unit_sphere * length + rc.reshape(-1, 1)

        # Compute the projection matrix once and reuse it for both calls.
        proj_matrix = ax.get_proj()
        transform = mplot3d.proj3d.proj_transform
        rp = np.asarray(transform(*rs, proj_matrix)[:2])
        rc[:2] = transform(*rc, proj_matrix)[:2]

        coords = rp - rc[:2].reshape(-1, 1)
        return sqrt(np.sum(coords**2, axis=0).max())

    # Auxiliary array for calculating the corners of a cube: for a bounding
    # box (xmin, xmax, ymin, ymax, zmin, zmax), np.dot(corners, bbox) yields
    # the x, y and z coordinates of its 8 corners. Each coordinate axis must
    # select from its own pair of limits. (Previously all entries were
    # written into corners[0], mixing the limits into the x coordinate.)
    # ``np.float_`` was removed in NumPy 2.0; np.float64 is the same type.
    corners = np.zeros((3, 8, 6), np.float64)
    corners[0, [0, 1, 2, 3], 0] = corners[0, [4, 5, 6, 7], 1] = \
        corners[1, [0, 1, 4, 5], 2] = corners[1, [2, 3, 6, 7], 3] = \
        corners[2, [0, 2, 4, 6], 4] = corners[2, [1, 3, 5, 7], 5] = 1.0
if mpl_available and has3d:
    class Line3DCollection(mplot3d.art3d.Line3DCollection):
        """3D line collection with data-unit line widths and plain zorder.

        Parameters
        ----------
        segments
            Passed on to mplot3d's ``Line3DCollection``.
        reflen : float, optional
            Reference length. When set, the widths recorded by
            `set_linewidths` are rescaled at draw time using
            `projected_length` and the mean of the data x/y scales.
        zorder : int
            Drawing order among collections; larger values end up on top.
        **kwargs
            Passed on to mplot3d's ``Line3DCollection``.
        """

        def __init__(self, segments, reflen=None, zorder=0, **kwargs):
            super().__init__(segments, **kwargs)
            self.reflen = reflen
            self.zorder3d = zorder

        def set_linewidths(self, linewidths):
            # Only record the requested widths; draw() applies the scaling.
            self.linewidths_orig = nparray_if_array(linewidths)

        def do_3d_projection(self, renderer):
            super().do_3d_projection(renderer)
            # mplot3d's depth ordering is unreliable with several
            # collections, so plain zorder is used instead. It is negated
            # because of mplot3d's reversed ordering logic; larger zorder
            # values still end up on top of smaller ones.
            return -self.zorder3d

        def draw(self, renderer):
            factor = 1
            if self.reflen:
                proj_len = projected_length(self.axes, self.reflen)
                args = self.axes.transData.frozen().to_values()
                # Equal aspect ratio cannot (currently) be enforced for 3D
                # plots in matplotlib, so approximate the scale by the
                # average of the x- and y-axis transformation.
                factor = proj_len * (args[0] +
                                     args[3]) * 0.5 * 72.0 / self.figure.dpi

            super().set_linewidths(self.linewidths_orig * factor)
            super().draw(renderer)
if mpl_available and has3d:
    class Path3DCollection(mplot3d.art3d.Patch3DCollection):
        """3D marker collection with data-unit sizes, depth sorting and zorder.

        Parameters
        ----------
        paths
            Marker paths; each is wrapped in a ``PathPatch``.
        sizes : sequence of float
            One scale factor per marker, applied through an affine transform.
        reflen : float, optional
            Reference length. When set, markers and line widths are scaled
            in data units using `projected_length`.
        zorder : int
            Base drawing order among collections; larger values end up on top.
        offsets : array of shape (n, 3), optional
            3D marker positions.
        **kwargs
            Passed on to mplot3d's ``Patch3DCollection``.
        """

        def __init__(self, paths, sizes, reflen=None, zorder=0,
                     offsets=None, **kwargs):
            paths = [matplotlib.patches.PathPatch(path) for path in paths]

            if offsets is not None:
                kwargs['offsets'] = offsets[:, :2]

            # Workaround for issue in Matplotlib-3.4.2 before PR merged
            # https://github.com/matplotlib/matplotlib/pull/20416
            self._z_markers_idx = slice(-1)

            super().__init__(paths, **kwargs)

            if offsets is not None:
                self.set_3d_properties(zs=offsets[:, 2], zdir="z")

            self.reflen = reflen
            self.zorder3d = zorder

            # Unsorted originals; do_3d_projection derives the depth-sorted
            # versions from these on every projection.
            self.paths_orig = np.array(paths, dtype='object')
            self.linewidths_orig = nparray_if_array(self.get_linewidths())
            self.linewidths_orig2 = self.linewidths_orig
            self.array_orig = nparray_if_array(self.get_array())
            self.facecolors_orig = nparray_if_array(self.get_facecolors())
            self.edgecolors_orig = nparray_if_array(self.get_edgecolors())

            Affine2D = matplotlib.transforms.Affine2D
            self.orig_transforms = np.array(
                [Affine2D().scale(x).get_matrix() for x in sizes])
            self.transforms = self.orig_transforms

        def set_array(self, array):
            self.array_orig = nparray_if_array(array)
            super().set_array(array)

        def set_color(self, colors):
            self.facecolors_orig = nparray_if_array(colors)
            self.edgecolors_orig = self.facecolors_orig
            super().set_color(colors)

        def set_edgecolors(self, colors):
            colors = matplotlib.colors.colorConverter.to_rgba_array(colors)
            self.edgecolors_orig = nparray_if_array(colors)
            super().set_edgecolors(colors)

        def get_transforms(self):
            # This is exact only for an isometric projection; for the
            # perspective projection used in mplot3d it's an approximation.
            return self.transforms

        def get_transform(self):
            Affine2D = matplotlib.transforms.Affine2D
            if self.reflen:
                proj_len = projected_length(self.axes, self.reflen)

                # For the paths, use the data transformation but strip the
                # offset (will be added later with the offsets).
                args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
                return Affine2D().from_values(*args).scale(proj_len)
            else:
                return Affine2D().scale(self.figure.dpi / 72.0)

        def do_3d_projection(self, renderer):
            xs, ys, zs = self._offsets3d

            # numpy complains about zero-length index arrays
            if len(xs) == 0:
                return -self.zorder3d

            proj = mplot3d.proj3d.proj_transform_clip
            vs = np.array(proj(xs, ys, zs, renderer.M)[:3])

            if sort3d:
                # Indices of the markers ordered by decreasing projected z.
                indx = vs[2].argsort()[::-1]
                n_points = vs.shape[1]

                self.set_offsets(vs[:2, indx].T)

                if len(self.paths_orig) > 1:
                    paths = np.resize(self.paths_orig, (n_points,))
                    self.set_paths(paths[indx])

                if len(self.orig_transforms) > 1:
                    # Permute the *original* transforms, as done for all other
                    # properties; permuting ``self.transforms`` would compound
                    # with the previous sort on every redraw.
                    self.transforms = self.orig_transforms[indx]

                lw_orig = self.linewidths_orig
                if isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1:
                    self.linewidths_orig2 = np.resize(lw_orig,
                                                      (n_points,))[indx]

                # Note: here array, facecolors and edgecolors are
                # guaranteed to be 2d numpy arrays or None. (And
                # array is the same length as the coordinates)

                if self.array_orig is not None:
                    super().set_array(self.array_orig[indx])

                if (self.facecolors_orig is not None and
                        self.facecolors_orig.shape[0] > 1):
                    shape = list(self.facecolors_orig.shape)
                    shape[0] = n_points
                    super().set_facecolors(
                        np.resize(self.facecolors_orig, shape)[indx])

                if (self.edgecolors_orig is not None and
                        self.edgecolors_orig.shape[0] > 1):
                    shape = list(self.edgecolors_orig.shape)
                    shape[0] = n_points
                    super().set_edgecolors(
                        np.resize(self.edgecolors_orig, shape)[indx])
            else:
                self.set_offsets(vs[:2].T)

            # The whole 3D ordering is flawed in mplot3d when several
            # collections are added. Use normal zorder, corrected by the
            # projected z-coord of the "center of gravity", normalized by the
            # projected z-range of the axes' bounding box. Several
            # Path3DCollections with the same zorder are then plotted in
            # probably the right order (not exact); smaller and larger
            # integer zorders still end up below or on top.
            bbox = np.asarray(self.axes.get_w_lims())

            proj = mplot3d.proj3d.proj_transform_clip
            cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2]

            # ``ndarray.ptp`` was removed in NumPy 2.0; np.ptp is equivalent.
            return -self.zorder3d + vs[2].mean() / np.ptp(cz)

        def draw(self, renderer):
            if self.reflen:
                proj_len = projected_length(self.axes, self.reflen)
                args = self.axes.transData.frozen().to_values()
                # Average of x- and y-axis scale (equal aspect ratio cannot
                # be enforced in 3D), converted to points.
                factor = proj_len * (args[0] +
                                     args[3]) * 0.5 * 72.0 / self.figure.dpi

                self.set_linewidths(self.linewidths_orig2 * factor)

            super().draw(renderer)
def matplotlib_to_plotly_cmap(cmap, pl_entries):
    """Sample a colormap into a plotly colorscale.

    Defined unconditionally: it needs neither plotly nor matplotlib itself,
    only a colormap-like callable.

    Parameters
    ----------
    cmap : callable
        Maps a level in [0, 1] to a color whose first three entries are the
        RGB channels in [0, 1] (e.g. a matplotlib colormap).
    pl_entries : int
        Number of evenly spaced samples; must be at least 2 so that both
        ends of [0, 1] are included.

    Returns
    -------
    list of [float, str]
        ``[level, "rgb(r, g, b)"]`` entries with integer channels in 0-255.

    Raises
    ------
    ValueError
        If *pl_entries* is smaller than 2.
    """
    if pl_entries < 2:
        # Previously a ZeroDivisionError for pl_entries == 1.
        raise ValueError(
            "pl_entries must be at least 2, got {}".format(pl_entries))
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # The original indexed a ``map`` object, which fails on Python 3.
        # Truncate each channel to a plain int in 0-255 so the string shows
        # "rgb(255, 0, 0)" rather than numpy scalar reprs.
        r, g, b = (np.array(cmap(k * h)[:3]) * 255).astype(np.uint8).tolist()
        pl_colorscale.append([k * h, 'rgb' + str((r, g, b))])

    return pl_colorscale