Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2# Copyright 2011-2019 Kwant authors.
3#
4# This file is part of Kwant. It is subject to the license terms in the file
5# LICENSE.rst found in the top-level directory of this distribution and at
6# https://kwant-project.org/license. A list of Kwant authors can be found in
7# the file AUTHORS.rst at the top-level directory of this distribution and at
8# https://kwant-project.org/authors.
10"""Plotter module for Kwant.
12This module provides iterators useful for any plotter routine, such as a list
13of system sites, their coordinates, lead sites at any lead unit cell, etc. If
14`matplotlib` is available, it also provides simple functions for plotting the
15system in two or three dimensions.
16"""
18from collections import defaultdict
19import sys
20import itertools
21import functools
22import warnings
23import cmath
24import numpy as np
25import tinyarray as ta
26from scipy import spatial, interpolate
27from math import cos, sin, pi, sqrt
29from . import system, builder, _common
30from ._common import deprecate_args
__all__ = [
    'set_engine', 'get_engine',
    'plot', 'map', 'bands', 'spectrum', 'current', 'density',
    'interpolate_current', 'interpolate_density',
    'streamplot', 'scalarplot',
    'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos',
    'sys_leads_hopping_pos', 'mask_interpolate',
]

# The expensive imports (matplotlib, plotly, ...) live in _plotter.py, which
# is loaded lazily so that importing Kwant itself stays fast.
_p = _common.lazy_import('_plotter')
def set_engine(engine):
    """Set the plotting engine to use.

    Parameters
    ----------
    engine : str
        Options are: 'matplotlib', 'plotly'.

    Raises
    ------
    RuntimeError
        If a plotting backend is available but ``engine`` is not one of the
        supported engines.

    Notes
    -----
    If neither matplotlib nor plotly is available, only a ``RuntimeWarning``
    is issued and the current engine is left unchanged. When the resulting
    engine is plotly and we run inside an IPython kernel, plotly's notebook
    mode is initialized (once).
    """
    if _p.mpl_available or _p.plotly_available:
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would silently accept unknown engines.
        try:
            known = engine in _p.engines
        except TypeError:  # e.g. an unhashable engine specification
            known = False
        if not known:
            error_message = "Tried to set an unknown engine \'{}\'.".format(
                engine)
            error_message += " Supported engines are {}".format(
                [e for e in _p.engines])
            raise RuntimeError(error_message)
        _p.engine = engine
    else:
        warnings.warn("Tried to set \'{}\' but is not "
                      "available.".format(engine), RuntimeWarning)

    if _p.engine == "plotly" and not _p.init_notebook_mode_set:
        if _p.is_ipython_kernel:
            _p.init_notebook_mode_set = True
            _p.plotly_module.init_notebook_mode(connected=True)
def get_engine():
    """Return the name of the plotting engine currently in use.

    This is the value last stored by ``set_engine``, or the initial value
    held by the lazily imported ``_plotter`` module.
    """
    current_engine = _p.engine
    return current_engine
def _check_incompatible_args_plotly(dpi, fig_size, ax):
    """Reject matplotlib-only options when the plotly engine is active.

    Parameters
    ----------
    dpi, fig_size, ax : object or None
        Matplotlib-specific plotting options; all must be left as ``None``
        with the plotly engine.

    Raises
    ------
    RuntimeError
        If any of ``dpi``, ``fig_size`` or ``ax`` is specified.
    """
    # Internal precondition: callers only use this with the plotly engine.
    assert _p.engine == "plotly"
    # Compare against None instead of truth-testing: ``bool()`` of a NumPy
    # ``fig_size`` array raises ValueError.
    if any(arg is not None for arg in (dpi, fig_size, ax)):
        raise RuntimeError(
            "Plotly engine does not support setting 'dpi', 'fig_size' "
            "or 'ax', either leave these parameters unspecified, or "
            "select the matplotlib engine with "
            "'kwant.plotter.set_engine(\"matplotlib\")'")
def _sample_array(array, n_samples, rng=None):
    """Return at most ``n_samples`` distinct entries of ``array``.

    Indices are drawn without replacement from the random generator obtained
    via ``_common.ensure_rng(rng)``. If ``array`` is shorter than
    ``n_samples``, every entry is returned (in the order drawn).
    ``array`` is indexed with an integer array, so it presumably needs to be
    a NumPy array.
    """
    generator = _common.ensure_rng(rng)
    population = len(array)
    n_picked = min(n_samples, population)
    chosen = generator.choice(range(population), n_picked, replace=False)
    return array[chosen]
# matplotlib helper functions.

def _color_cycle():
    """Yield the colors of matplotlib's ``axes.prop_cycle``, forever."""
    prop_cycle = _p.matplotlib.rcParams['axes.prop_cycle']
    colors = (entry['color'] for entry in prop_cycle)
    return itertools.cycle(colors)
def _make_figure(dpi, fig_size, use_pyplot=False):
    """Create a new matplotlib figure.

    Parameters
    ----------
    dpi : float or None
        Figure resolution; matplotlib's default is kept if ``None``.
    fig_size : (width, height) or None
        Passed to ``set_figwidth``/``set_figheight``; matplotlib's default is
        kept if ``None``.
    use_pyplot : bool
        If True the figure is created with ``matplotlib.pyplot.figure()``,
        otherwise a bare ``Figure`` attached to an Agg canvas is created.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    # Importing pyplot or a backend selects the matplotlib backend for good.
    # Warn if no backend has been set yet; this is the same check as the one
    # performed inside matplotlib.use.
    if 'matplotlib.backends' not in sys.modules:
        warnings.warn(
            "Kwant's plotting functions have\nthe side effect of "
            "selecting the matplotlib backend. To avoid this "
            "warning,\nimport matplotlib.pyplot, "
            "matplotlib.backends or call matplotlib.use().",
            RuntimeWarning, stacklevel=3
        )

    if not use_pyplot:
        from matplotlib.backends.backend_agg import FigureCanvasAgg
        fig = _p.Figure()
        fig.canvas = FigureCanvasAgg(fig)
    else:
        # Imported at the last possible moment (=now), because importing
        # pyplot fixes the backend.
        from matplotlib import pyplot
        fig = pyplot.figure()

    if dpi is not None:
        fig.set_dpi(dpi)
    if fig_size is not None:
        width, height = fig_size[0], fig_size[1]
        fig.set_figwidth(width)
        fig.set_figheight(height)
    return fig
def _maybe_output_fig(fig, file=None, show=True):
    """Output a figure using a given output mode.

    Parameters
    ----------
    fig : matplotlib.figure.Figure instance, plotly figure, or None
        The figure to be output. Nothing happens if it is ``None``.
    file : string or a file object
        The name of the target file or the target file itself
        (opened for writing).
    show : bool
        Whether to display the figure. With the matplotlib engine this calls
        ``matplotlib.pyplot.show()`` and only has an effect if not saving to
        a file. With the plotly engine the figure is displayed with
        ``iplot`` in addition to any file output.

    Raises
    ------
    RuntimeError
        With the plotly engine, if ``show`` is true outside of an
        IPython/Jupyter kernel.

    Notes
    -----
    The behavior of this function producing a file is different from that of
    matplotlib in that the `dpi` attribute of the figure is used by default
    instead of the matplotlib config setting.
    """
    if fig is None:
        return

    engine = _p.engine
    if engine == "matplotlib":
        if file is not None:
            fig.canvas.print_figure(file, dpi=fig.dpi)
            return
        if show:
            # Without a file, pyplot should already be available, so it can
            # be imported here without additional backend warnings.
            from matplotlib import pyplot
            pyplot.show()
        return

    if engine == "plotly":
        if file is not None:
            _p.plotly_module.plot(fig, show_link=False, filename=file,
                                  auto_open=False)
        if not show:
            return
        if not _p.is_ipython_kernel:
            raise RuntimeError('show flag using the plotly engine can '
                               'only be True if and only if called from a '
                               'jupyter/ipython environment.')
        _p.plotly_module.iplot(fig)
def set_colors(color, collection, cmap, norm=None):
    """Process a color specification to a format accepted by collections.

    A sequence with one float per element of ``collection`` is color-mapped
    via ``cmap`` and ``norm``; anything else is converted with matplotlib's
    ``to_rgba_array`` and set as explicit colors.

    Parameters
    ----------
    color : color specification
    collection : instance of a subclass of ``matplotlib.collections.Collection``
        Collection to which the color is added.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    """
    n_items = max(len(collection.get_paths()), len(collection.get_offsets()))

    # matplotlib gets confused if dtype='object'
    if isinstance(color, np.ndarray) and color.dtype == np.dtype('object'):
        color = tuple(color)

    if _p.has3d and isinstance(collection,
                               _p.mplot3d.art3d.Line3DCollection):
        # Apparently get_paths/get_offsets do not report the element count
        # of 3D line collections, so read it from the private 3D segments.
        n_items = len(collection._segments3d)

    one_per_item = _p.isarray(color) and len(color) == n_items
    if one_per_item:
        try:
            color = np.asarray(color, dtype=float)
            if color.ndim == 1:
                # One float per element: use color mapping.
                collection.set_array(color)
                collection.set_cmap(cmap)
                collection.set_norm(norm)
                collection.set_color(None)
                return
        except (TypeError, ValueError):
            # Not convertible to floats: fall back to explicit colors below.
            pass

    rgba = _p.matplotlib.colors.colorConverter.to_rgba_array(color)
    collection.set_color(rgba)
def percentile_bound(data, vmin, vmax, percentile=96, stretch=0.1):
    """Return the bounds that capture at least 'percentile' of 'data'.

    If 'vmin' or 'vmax' are provided, then the corresponding bound is
    exactly 'vmin' or 'vmax'. First we set the bounds such that the
    provided percentile of the data is within them. Then we try to
    extend the bounds to cover all the data, maximally stretching each
    bound by a factor 'stretch' of the distance between the bounds.

    Parameters
    ----------
    data : array_like
        Values of any shape; they are flattened before analysis.
    vmin, vmax : float or None
        Fixed lower/upper bound, or ``None`` to derive it from ``data``.
    percentile : float
        Percentage of the data (centered, in [0, 100]) that must lie within
        the bounds.
    stretch : float
        Maximal extension of each bound, as a fraction of the distance
        between the percentile bounds.

    Returns
    -------
    (lower, upper) : tuple of floats
    """
    if vmin is not None and vmax is not None:
        return vmin, vmax

    # Accept any array_like (lists, nested sequences), not only ndarrays.
    values = np.asarray(data).ravel()
    tail = (100 - percentile) / 2
    mn, bound_mn, bound_mx, mx = np.percentile(
        values, (0, tail, 100 - tail, 100))

    bound_mn = bound_mn if vmin is None else vmin
    bound_mx = bound_mx if vmax is None else vmax

    # Stretch the lower and upper bounds to cover all the data, if
    # we stretch the bound by less than a factor 'stretch'.
    margin = (bound_mx - bound_mn) * stretch
    out_mn = max(bound_mn - margin, mn) if vmin is None else vmin
    out_mx = min(bound_mx + margin, mx) if vmax is None else vmax

    return (out_mn, out_mx)
symbol_dict = {'O': 'o', 's': ('p', 4, 45), 'S': ('P', 4, 45)}


def get_symbol(symbols):
    """Return the paths corresponding to the description in ``symbols``.

    Parameters
    ----------
    symbols : symbol specification or sequence of specifications
        A specification is one of:

        - 'o' (or 'O'): circle of radius 1,
        - 's' / 'S': shorthands for ``('p', 4, 45)`` / ``('P', 4, 45)``,
        - ``(kind, n, angle)`` with ``kind`` 'p' or 'P': regular polygon with
          ``n`` vertices rotated by ``angle`` degrees. For 'p' the inscribed
          circle has radius 1; for 'P' the polygon has the area of a circle
          of radius 1,
        - a ``matplotlib.path.Path`` instance, used as is.

    Returns
    -------
    paths : list of matplotlib.path.Path
        One path per symbol specification.

    Raises
    ------
    ValueError
        If a symbol specification is not recognized.
    """
    # Figure out if list of symbols or single symbol.
    if not hasattr(symbols, '__getitem__'):
        symbols = [symbols]
    elif len(symbols) == 3 and symbols[0] in ('p', 'P'):
        # Most likely a polygon specification (at least not a valid other
        # symbol).
        symbols = [symbols]

    paths = []
    for symbol in symbols:
        if symbol in symbol_dict:
            symbol = symbol_dict[symbol]

        if isinstance(symbol, _p.matplotlib.path.Path):
            # Ready-made path: keep it, but still return a list of paths.
            paths.append(symbol)
            continue

        if hasattr(symbol, '__getitem__') and len(symbol) == 3:
            kind, n, angle = symbol
            if kind not in ('p', 'P'):
                raise ValueError("Unknown symbol definition " + str(symbol))
            if kind == 'p':
                # Circumradius such that the inscribed circle has radius 1.
                radius = 1. / cos(pi / n)
            else:
                # make the polygon such that it has area equal
                # to a unit circle
                radius = sqrt(2 * pi / (n * sin(2 * pi / n)))
            patch = _p.matplotlib.patches.RegularPolygon(
                (0, 0), n, radius=radius, orientation=pi * angle / 180)
        elif symbol == 'o':
            patch = _p.matplotlib.patches.Circle((0, 0), 1)
        else:
            # Previously this left ``patch`` unbound (or silently reused the
            # patch of the preceding symbol).
            raise ValueError("Unknown symbol definition " + str(symbol))

        paths.append(patch.get_path().transformed(patch.get_transform()))

    return paths
def symbols(axes, pos, symbol='o', size=1, reflen=None, facecolor='k',
            edgecolor='k', linewidth=None, cmap=None, norm=None, zorder=0,
            **kwargs):
    """Add a collection of symbols (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the symbols have to be added.
    pos : 2d or 3d array_like
        Coordinates of each symbol, shape ``(n, 2)`` or ``(n, 3)``.
    symbol : symbol definition
        Passed to ``get_symbol``: e.g. 'o', 's', ``('p', nvert, angle)``, a
        ``matplotlib.path.Path``, or a sequence of these. 'no symbol' draws
        nothing.
    size : float or 1d array
        Size(s) of the symbols. Defaults to 1.
    reflen : float or None, optional
        If ``reflen`` is ``None``, the symbol sizes and linewidths are
        given in points (absolute size in the figure space). If
        ``reflen`` is a number, the symbol sizes and linewidths are
        given in units of ``reflen`` in data space (i.e. scales with the
        scale of the plot). Defaults to ``None``.
    facecolor : color definition, optional
    edgecolor : color definition, optional
        Defines the fill and edge color of the symbol, respectively.
        Either a single object that is a proper matplotlib color
        definition or a sequence of such objects of appropriate
        length. Defaults to all black.
    linewidth : float, sequence of floats, or None
        Linewidth(s) of the symbol edges.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    zorder : int
        Order in which different collections are drawn: larger
        ``zorder`` means the collection is drawn over collections with
        smaller ``zorder`` values.
    **kwargs : dict keyword arguments to
        pass to `PathCollection` or `Path3DCollection`, respectively.

    Returns
    -------
    `PathCollection` or `Path3DCollection` instance containing all the
    symbols that were added.
    """
    pos = np.asarray(pos)
    dim = pos.shape[1]
    assert dim == 2 or dim == 3

    # Internally, size must be array_like. ``np.ndim`` also recognizes NumPy
    # scalars, for which ``size[0]`` raises IndexError rather than TypeError.
    if np.ndim(size) == 0:
        size = (size, )

    Collection = _p.PathCollection if dim == 2 else _p.Path3DCollection

    # NOTE(review): for a scalar size (now wrapped in a tuple) the comparison
    # ``size == 0`` is always False, so zero-size symbols are still created.
    if len(pos) == 0 or np.all(symbol == 'no symbol') or np.all(size == 0):
        paths = []
        pos = np.empty((0, dim))
    else:
        paths = get_symbol(symbol)

    coll = Collection(paths, sizes=size, reflen=reflen, linewidths=linewidth,
                      offsets=pos, transOffset=axes.transData, zorder=zorder)

    set_colors(facecolor, coll, cmap, norm)
    coll.set_edgecolors(edgecolor)

    coll.update(kwargs)

    if dim == 2:
        axes.add_collection(coll)
    else:
        axes.add_collection3d(coll)

    return coll
def lines(axes, pos0, pos1, reflen=None, colors='k', linestyles='solid',
          cmap=None, norm=None, zorder=0, **kwargs):
    """Add a collection of line segments (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the lines have to be added.
    pos0 : 2d or 3d array_like
        Starting coordinates of each line segment
    pos1 : 2d or 3d array_like
        Ending coordinates of each line segment
    reflen : float or None, optional
        If `reflen` is `None`, the linewidths are given in points (absolute
        size in the figure space). If `reflen` is a number, the linewidths
        are given in units of `reflen` in data space (i.e. scales with
        the scale of the plot). Defaults to `None`.
    colors : color definition, optional
        Either a single object that is a proper matplotlib color definition
        or a sequence of such objects of appropriate length. Defaults to all
        segments black.
    linestyles : linestyle definition, optional
        Either a single object that is a proper matplotlib line style
        definition or a sequence of such objects of appropriate length.
        Defaults to all segments solid.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    zorder : int
        Order in which different collections are drawn: larger
        `zorder` means the collection is drawn over collections with
        smaller `zorder` values.
    **kwargs : dict keyword arguments to
        pass to `LineCollection` or `Line3DCollection`, respectively.
        If ``linewidths`` is given and all its values are zero, an empty
        collection is added.

    Returns
    -------
    `LineCollection` or `Line3DCollection` instance containing all the
    segments that were added.

    Raises
    ------
    ValueError
        If ``pos0`` and ``pos1`` have different shapes.
    """
    if not pos0.shape == pos1.shape:
        raise ValueError('Incompatible lengths of coordinate arrays.')

    dim = pos0.shape[1]
    assert dim == 2 or dim == 3
    Collection = _p.LineCollection if dim == 2 else _p.Line3DCollection

    def add(coll):
        # 2D and 3D axes use different methods to register a collection.
        if dim == 2:
            axes.add_collection(coll)
        else:
            axes.add_collection3d(coll)
        return coll

    # ``linewidths`` may be a scalar or one width per segment; comparing via
    # NumPy avoids the ambiguous truth value of ``array == 0``.
    zero_width = ('linewidths' in kwargs and
                  np.all(np.asarray(kwargs['linewidths']) == 0))
    if len(pos0) == 0 or zero_width:
        coll = Collection([], reflen=reflen, linestyles=linestyles,
                          zorder=zorder)
        coll.update(kwargs)
        return add(coll)

    segments = np.c_[pos0, pos1].reshape(pos0.shape[0], 2, dim)

    coll = Collection(segments, reflen=reflen, linestyles=linestyles,
                      zorder=zorder)
    set_colors(colors, coll, cmap, norm)
    coll.update(kwargs)
    return add(coll)
# Extracting necessary data from the system.

def sys_leads_sites(sys, num_lead_cells=2):
    """Return all the sites of the system and of the leads as a list.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.System instance
        The system, sites of which should be returned.
    num_lead_cells : integer
        The number of times lead sites from each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    sites : list of (site, lead_number, copy_number) tuples
        A site is a `~kwant.system.Site` instance if the system is not finalized,
        and an integer otherwise. For system sites `lead_number` is `None` and
        `copy_number` is `0`, for leads both are integers.
    lead_cells : list of slices
        `lead_cells[i]` gives the position of all the coordinates of lead
        `i` within `sites`. The slice is empty for leads that are not drawn.

    Raises
    ------
    TypeError
        If ``sys`` is neither a Builder nor a finite finalized system.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e. sites of `~kwant.builder.BuilderLead` leads are returned with an
    unfinalized system, and sites of ``system.InfiniteSystem`` leads are
    returned with a finalized system.
    """
    syst = sys  # for naming consistency within function bodies
    lead_cells = []

    if isinstance(syst, builder.Builder):
        sites = [(site, None, 0) for site in syst.sites()]
        for leadnr, lead in enumerate(syst.leads):
            first = len(sites)
            # Only leads backed by a builder and with a non-empty interface.
            if hasattr(lead, 'builder') and len(lead.interface):
                sites.extend((site, leadnr, copy)
                             for site in lead.builder.sites()
                             for copy in range(num_lead_cells))
            lead_cells.append(slice(first, len(sites)))
        return sites, lead_cells

    if not system.is_finite(syst):
        raise TypeError('Unrecognized system type.')

    sites = [(i, None, 0) for i in range(syst.graph.num_nodes)]
    for leadnr, lead in enumerate(syst.leads):
        first = len(sites)
        # We will only plot leads with a graph and with a symmetry.
        drawable = (hasattr(lead, 'graph') and hasattr(lead, 'symmetry') and
                    len(syst.lead_interfaces[leadnr]))
        if drawable:
            sites.extend((site, leadnr, copy)
                         for site in range(lead.cell_size)
                         for copy in range(num_lead_cells))
        lead_cells.append(slice(first, len(sites)))
    return sites, lead_cells
def sys_leads_pos(sys, site_lead_nr):
    """Return an array of positions of sites in a system.

    Parameters
    ----------
    sys : `kwant.builder.Builder` or `kwant.system.System` instance
        The system, coordinates of sites of which should be returned.
    site_lead_nr : list of `(site, leadnr, copynr)` tuples
        Output of `sys_leads_sites` applied to the system.

    Returns
    -------
    coords : numpy.ndarray of floats
        Array of coordinates of the sites. Lead sites are translated by the
        lead period according to their copy number.

    Raises
    ------
    ValueError
        If the site positions do not all have the same dimensionality.

    Notes
    -----
    This function uses `site.pos` property to get the position of a builder
    site and `sys.pos(sitenr)` for finalized systems. This function requires
    that all the positions of all the sites have the same dimensionality.
    """
    # Note about efficiency (also applies to sys_leads_hopping_pos)
    # NumPy is really slow when making a NumPy array from a tinyarray
    # (buffer interface seems very slow). It's much faster to first
    # convert to a tuple and then to convert to numpy array ...

    syst = sys  # for naming consistency inside function bodies
    is_builder = isinstance(syst, builder.Builder)
    num_lead_cells = site_lead_nr[-1][2] + 1
    if is_builder:
        pos = np.array(ta.array([i[0].pos for i in site_lead_nr]))
    else:
        syst_from_lead = lambda lead: (syst if (lead is None)
                                       else syst.leads[lead])
        pos = np.array(ta.array([syst_from_lead(i[1]).pos(i[0])
                                 for i in site_lead_nr]))
    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    dim = pos.shape[1]

    def get_vec_domain(lead_nr):
        # Return the lead's first period vector and one past the symmetry
        # domain of its first interface site. Leads that are not drawn
        # (no ``builder``/``symmetry`` attribute, or an empty interface) get
        # (0, 0); ``sys_leads_sites`` emits no entries for them. Catching
        # AttributeError in the builder branch too avoids crashing on
        # builder leads without a ``builder`` attribute.
        try:
            if is_builder:
                sym = syst.leads[lead_nr].builder.symmetry
                site = syst.leads[lead_nr].interface[0]
            else:
                sym = syst.leads[lead_nr].symmetry
                site = syst.sites[syst.lead_interfaces[lead_nr][0]]
        except (AttributeError, IndexError):
            return (0, 0)
        dom = sym.which(site)[0] + 1
        # Conversion to numpy array here useful for efficiency
        vec = np.array(sym.periods)[0]
        return vec, dom

    vecs_doms = {i: get_vec_domain(i) for i in range(len(syst.leads))}
    vecs_doms[None] = np.zeros((dim,)), 0
    # Translation of every copy of every lead (all zeros for the system).
    offsets = {k: [vec * i for i in range(dom, dom + num_lead_cells)]
               for k, (vec, dom) in vecs_doms.items()}
    pos += [offsets[i[1]][i[2]] for i in site_lead_nr]
    return pos
def sys_leads_hoppings(sys, num_lead_cells=2):
    """Return all the hoppings of the system and of the leads as a list.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.System instance
        The system, hoppings of which should be returned.
    num_lead_cells : integer
        The number of times the hoppings of each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    hoppings : list of (hopping, lead_number, copy_number) tuples
        A hopping is a pair of sites. A site is a `~kwant.system.Site`
        instance if the system is not finalized, and an integer otherwise.
        For system hoppings `lead_number` is `None` and `copy_number` is `0`,
        for leads both are integers.
    lead_cells : list of slices
        `lead_cells[i]` gives the position of all the coordinates of lead
        `i` within `hoppings`.

    Raises
    ------
    TypeError
        If ``sys`` is neither a Builder nor a ``kwant.system.System``.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e. hoppings of `~kwant.builder.BuilderLead` leads are returned with an
    unfinalized system, and hoppings of `~kwant.system.InfiniteSystem` leads are
    returned with a finalized system.
    """
    syst = sys  # for naming consistency inside function bodies
    hoppings = []
    lead_cells = []

    if isinstance(syst, builder.Builder):
        hoppings.extend((hop, None, 0) for hop in syst.hoppings())

        def lead_hoppings(lead):
            sym = lead.symmetry
            for site2, site1 in lead.hoppings():
                # Translate the hopping so that it connects a site in the
                # fundamental domain with one in a negative domain. The
                # direction of the hopping is chosen arbitrarily.
                # NOTE(Anton): This may need to be revisited with the future
                # builder format changes.
                shift = max(sym.which(site1)[0], sym.which(site2)[0])
                yield sym.act([-shift], site2), sym.act([-shift], site1)

        for leadnr, lead in enumerate(syst.leads):
            first = len(hoppings)
            if hasattr(lead, 'builder') and len(lead.interface):
                hoppings.extend((hop, leadnr, copy)
                                for hop in lead_hoppings(lead.builder)
                                for copy in range(num_lead_cells))
            lead_cells.append(slice(first, len(hoppings)))
        return hoppings, lead_cells

    if not isinstance(syst, system.System):
        raise TypeError('Unrecognized system type.')

    def ll_hoppings(graph_syst):
        # Each undirected graph edge once, as (i, j) with i < j.
        for i in range(graph_syst.graph.num_nodes):
            yield from ((i, j) for j in graph_syst.graph.out_neighbors(i)
                        if i < j)

    hoppings.extend((hop, None, 0) for hop in ll_hoppings(syst))
    for leadnr, lead in enumerate(syst.leads):
        first = len(hoppings)
        # We will only plot leads with a graph and with a symmetry.
        if (hasattr(lead, 'graph') and hasattr(lead, 'symmetry') and
                len(syst.lead_interfaces[leadnr])):
            hoppings.extend((hop, leadnr, copy)
                            for hop in ll_hoppings(lead)
                            for copy in range(num_lead_cells))
        lead_cells.append(slice(first, len(hoppings)))
    return hoppings, lead_cells
def sys_leads_hopping_pos(sys, hop_lead_nr):
    """Return arrays of coordinates of all hoppings in a system.

    Parameters
    ----------
    sys : ``~kwant.builder.Builder`` or ``~kwant.system.System`` instance
        The system, coordinates of hoppings of which should be returned.
    hop_lead_nr : list of ``(hopping, leadnr, copynr)`` tuples
        Output of `sys_leads_hoppings` applied to the system.

    Returns
    -------
    coords : (end_site, start_site): tuple of NumPy arrays of floats
        Coordinates of the first site (``hopping[0]``) and of the second
        site (``hopping[1]``) of every hopping, one row per entry of
        ``hop_lead_nr``. Lead hoppings are translated by the lead period
        according to their copy number. If ``hop_lead_nr`` is empty, two
        empty ``(0, 3)`` arrays are returned.

    Raises
    ------
    ValueError
        If the site positions do not all have the same dimensionality.

    Notes
    -----
    This function uses ``site.pos`` property to get the position of a builder
    site and ``sys.pos(sitenr)`` for finalized systems. This function requires
    that all the positions of all the sites have the same dimensionality.
    """
    syst = sys  # for naming consistency inside function bodies
    is_builder = isinstance(syst, builder.Builder)
    if len(hop_lead_nr) == 0:
        return np.empty((0, 3)), np.empty((0, 3))
    num_lead_cells = hop_lead_nr[-1][2] + 1
    # Each row holds the coordinates of both sites, concatenated.
    if is_builder:
        pos = np.array(ta.array([ta.array(tuple(i[0][0].pos) +
                                          tuple(i[0][1].pos)) for i in
                                 hop_lead_nr]))
    else:
        syst_from_lead = lambda lead: (syst if (lead is None) else
                                       syst.leads[lead])
        pos = ta.array([ta.array(tuple(syst_from_lead(i[1]).pos(i[0][0])) +
                                 tuple(syst_from_lead(i[1]).pos(i[0][1])))
                        for i in hop_lead_nr])
        pos = np.array(pos)
    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    dim = pos.shape[1]

    def get_vec_domain(lead_nr):
        # Return the lead's first period vector (duplicated, to shift both
        # sites of a hopping) and one past the symmetry domain of its first
        # interface site. Leads that are not drawn (no ``builder``/
        # ``symmetry`` attribute, or an empty interface) get (0, 0);
        # ``sys_leads_hoppings`` emits no entries for them. Catching
        # AttributeError in the builder branch too avoids crashing on
        # builder leads without a ``builder`` attribute.
        try:
            if is_builder:
                sym = syst.leads[lead_nr].builder.symmetry
                site = syst.leads[lead_nr].interface[0]
            else:
                sym = syst.leads[lead_nr].symmetry
                site = syst.sites[syst.lead_interfaces[lead_nr][0]]
        except (AttributeError, IndexError):
            return (0, 0)
        dom = sym.which(site)[0] + 1
        vec = np.array(sym.periods)[0]
        return np.r_[vec, vec], dom

    vecs_doms = {i: get_vec_domain(i) for i in range(len(syst.leads))}
    vecs_doms[None] = np.zeros((dim,)), 0
    # Translation of every copy of every lead (all zeros for the system).
    offsets = {k: [vec * i for i in range(dom, dom + num_lead_cells)]
               for k, (vec, dom) in vecs_doms.items()}
    pos += [offsets[i[1]][i[2]] for i in hop_lead_nr]
    return np.copy(pos[:, : dim // 2]), np.copy(pos[:, dim // 2:])
# Useful plot functions (to be extended).

# Default values of the plotting options. For the entries that are dicts, the
# keys 2 and 3 are the dimension of the system: e.g. the default site_size is
# 0.25 for a 2D Kwant system and 0.5 for a 3D one. The default plotly symbol
# size is 6 px ('plotly_site_size_reference').
defaults = {
    'site_symbol': {2: 'o', 3: 'o'},
    'site_size': {2: 0.25, 3: 0.5},
    'plotly_site_size_reference': 6,
    'site_color': {2: 'black', 3: 'white'},
    'site_edgecolor': {2: 'black', 3: 'black'},
    'site_lw': {2: 0, 3: 0.1},
    'hop_color': {2: 'black', 3: 'black'},
    'hop_lw': {2: 0.1, 3: 0},
    'lead_color': {2: 'red', 3: 'red'},
}
738def plot(sys, num_lead_cells=2, unit=None,
739 site_symbol=None, site_size=None,
740 site_color=None, site_edgecolor=None, site_lw=None,
741 hop_color=None, hop_lw=None,
742 lead_site_symbol=None, lead_site_size=None, lead_color=None,
743 lead_site_edgecolor=None, lead_site_lw=None,
744 lead_hop_lw=None, pos_transform=None,
745 cmap='gray', colorbar=True, file=None,
746 show=True, dpi=None, fig_size=None, ax=None):
747 """Plot a system in 2 or 3 dimensions.
749 An alias exists for this common name: ``kwant.plot``.
751 Parameters
752 ----------
753 sys : kwant.builder.Builder or kwant.system.FiniteSystem
754 A system to be plotted.
755 num_lead_cells : int
756 Number of lead copies to be shown with the system.
757 unit : 'nn', 'pt', or float
758 The unit used to specify symbol sizes and linewidths.
759 Possible choices are:
761 - 'nn': unit is the shortest hopping or a typical nearst neighbor
762 distance in the system if there are no hoppings. This means that
763 symbol sizes/linewidths will scale as the zoom level of the figure is
764 changed. Very short distances are discarded before searching for the
765 shortest. This choice means that the symbols will scale if the
766 figure is zoomed.
767 - 'pt': unit is points (point = 1/72 inch) in figure space. This means
768 that symbols and linewidths will always be drawn with the same size
769 independent of zoom level of the plot.
770 - float: sizes are given in units of this value in real (system) space,
771 and will accordingly scale as the plot is zoomed.
773 The default value is 'nn', which allows to ensure that the images
774 neighboring sites do not overlap.
776 site_symbol : symbol specification, function, array, or `None`
777 Symbol used for representing a site in the plot. Can be specified as
779 - 'o': circle with radius of 1 unit.
780 - 's': square with inner circle radius of 1 unit.
781 - ``('p', nvert, angle)``: regular polygon with ``nvert`` vertices,
782 rotated by ``angle``. ``angle`` is given in degrees, and ``angle=0``
783 corresponds to one edge of the polygon pointing upward. The
784 radius of the inner circle is 1 unit. [Unsupported by plotly engine]
785 - 'no symbol': no symbol is plotted. [Unsupported by plotly engine]
786 - 'S', `('P', nvert, angle)`: as the lower-case variants described
787 above, but with an area equal to a circle of radius 1. (Makes
788 the visual size of the symbol equal to the size of a circle with
789 radius 1). [Unsupported by plotly engine]
790 - matplotlib.path.Path instance. [Unsupported by plotly engine]
792 Instead of a single symbol, different symbols can be specified
793 for different sites by passing a function that returns a valid
794 symbol specification for each site, or by passing an array of
795 symbols specifications (only for kwant.system.FiniteSystem).
796 site_size : number, function, array, or `None`
797 Relative (linear) size of the site symbol.
798 An array may not be used when 'syst' is a kwant.Builder.
799 site_color : ``matplotlib`` color description, function, array, or `None`
800 A color used for plotting a site in the system. If a colormap is used,
801 it should be a function returning single floats or a one-dimensional
802 array of floats. By default sites are colored by their site family,
803 using the current matplotlib color cycle.
804 An array of colors may not be used when 'syst' is a kwant.Builder.
805 site_edgecolor : ``matplotlib`` color description, function, array, or `None`
806 Color used for plotting the edges of the site symbols. Only
807 valid matplotlib color descriptions are allowed (and no
808 combination of floats and colormap as for site_color).
809 An array of colors may not be used when 'syst' is a kwant.Builder.
810 site_lw : number, function, array, or `None`
811 Linewidth of the site symbol edges.
812 An array may not be used when 'syst' is a kwant.Builder.
813 hop_color : ``matplotlib`` color description or a function
814 Same as `site_color`, but for hoppings. A function is passed two sites
815 in this case. (arrays are not allowed in this case).
816 hop_lw : number, function, or `None`
817 Linewidth of the hoppings.
818 lead_site_symbol : symbol specification or `None`
819 Symbol to be used for the leads. See `site_symbol` for allowed
820 specifications. Note that for leads, only constants
821 (i.e. no functions or arrays) are allowed. If None, then
822 `site_symbol` is used if it is constant (i.e. no function or array),
823 the default otherwise. The same holds for the other lead properties
824 below.
825 lead_site_size : number or `None`
826 Relative (linear) size of the lead symbol
827 lead_color : ``matplotlib`` color description or `None`
828 For the leads, `num_lead_cells` copies of the lead unit cell
829 are plotted. They are plotted in color fading from `lead_color`
830 to white (alpha values in `lead_color` are supported) when moving
831 from the system into the lead. Is also applied to the
832 hoppings.
833 lead_site_edgecolor : ``matplotlib`` color description or `None`
834 Color of the symbol edges (no fading done).
835 lead_site_lw : number or `None`
836 Linewidth of the lead symbols.
837 lead_hop_lw : number or `None`
838 Linewidth of the lead hoppings.
839 cmap : ``matplotlib`` color map or a sequence of two color maps or `None`
840 The color map used for sites and optionally hoppings.
841 pos_transform : function or `None`
842 Transformation to be applied to the site position.
843 colorbar : bool
844 Whether to show a colorbar if colormap is used. Ignored if `ax` is
845 provided.
846 file : string or file object or `None`
847 The output file. If `None`, output will be shown instead.
848 show : bool
849 Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
850 to be shown immediately. Defaults to `True`.
851 dpi : float or `None`
852 Number of pixels per inch. If not set the ``matplotlib`` default is
853 used.
854 fig_size : tuple or `None`
855 Figure size `(width, height)` in inches. If not set, the default
856 ``matplotlib`` value is used.
857 ax : ``matplotlib.axes.Axes`` instance or `None`
858 If `ax` is not `None`, no new figure is created, but the plot is done
859 within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
860 and `fig_size` are ignored.
862 Returns
863 -------
864 fig : matplotlib figure
865 A figure with the output if `ax` is not set, else None.
867 Notes
868 -----
869 - If `None` is passed for a plot property, a default value depending on
870 the dimension is chosen. Typically, the default values result in
871 acceptable plots.
873 - The meaning of "site" depends on whether the system to be plotted is a
874 builder or a low level system. For builders, a site is a
875 kwant.system.Site object. For low level systems, a site is an integer
876 -- the site number.
878 - color and symbol definitions may be tuples, but not lists or arrays.
879 Arrays of values (linewidths, colors, sizes) may not be tuples.
881 - The dimensionality of the plot (2D vs 3D) is inferred from the coordinate
882 array. If there are more than three coordinates, only the first three
883 are used. If there is just one coordinate, the second one is padded with
884 zeros.
886 - The system is scaled to fit the smaller dimension of the figure, given
887 its aspect ratio.
889 """
891 # Provide default unit if user did not specify
892 if _p.engine == "matplotlib":
893 fig = _plot_matplotlib(sys, num_lead_cells, unit,
894 site_symbol, site_size,
895 site_color, site_edgecolor, site_lw,
896 hop_color, hop_lw,
897 lead_site_symbol, lead_site_size, lead_color,
898 lead_site_edgecolor, lead_site_lw,
899 lead_hop_lw, pos_transform,
900 cmap, colorbar, file,
901 show, dpi, fig_size, ax)
902 elif _p.engine == "plotly": 902 ↛ 913line 902 didn't jump to line 913, because the condition on line 902 was never false
903 _check_incompatible_args_plotly(dpi, fig_size, ax)
904 fig = _plot_plotly(sys, num_lead_cells, unit,
905 site_symbol, site_size,
906 site_color, site_edgecolor, site_lw,
907 hop_color, hop_lw,
908 lead_site_symbol, lead_site_size, lead_color,
909 lead_site_edgecolor, lead_site_lw,
910 lead_hop_lw, pos_transform,
911 cmap, colorbar, file,
912 show)
913 elif _p.engine is None:
914 raise RuntimeError("Cannot use plot() without a plotting lib installed")
915 else:
916 raise RuntimeError("plot() does not support engine '{}'".format(_p.engine))
918 _maybe_output_fig(fig, file=file, show=show)
920 return fig
922def _resize_to_dim(array, dim):
923 if array.shape[1] != dim:
924 ar = np.zeros((len(array), dim), dtype=float)
925 ar[:, : min(dim, array.shape[1])] = array[
926 :, : min(dim, array.shape[1])]
927 return ar
928 else:
929 return array
932def _check_length(name, loc):
933 value = loc[name]
934 if name in ('site_size', 'site_lw') and isinstance(value, tuple): 934 ↛ 935line 934 didn't jump to line 935, because the condition on line 934 was never true
935 raise TypeError('{0} may not be a tuple, use list or '
936 'array instead.'.format(name))
937 if isinstance(value, (str, tuple)):
938 return
939 try:
940 if len(value) != loc['n_syst_sites']: 940 ↛ 941line 940 didn't jump to line 941, because the condition on line 940 was never true
941 raise ValueError('Length of {0} is not equal to number of '
942 'system sites.'.format(name))
943 except TypeError:
944 pass
# make all specs proper: either constant or lists/np.arrays:
def _make_proper_site_spec(spec_name, spec, syst, sites, fancy_indexing=False):
    """Normalize a per-site plot property specification.

    Parameters
    ----------
    spec_name : str
        Name of the property; only used in the error message.
    spec : constant, callable, array or None
        The user-supplied specification. A callable is evaluated on every
        system site and replaced by the list of results.
    syst : kwant.builder.Builder or finalized system
        The system being plotted.
    sites : sequence of tuples
        Sites as produced by `sys_leads_sites`. ``i[0]`` is the site, and
        ``i[1]`` is None for system (non-lead) sites; only those are passed
        to a callable ``spec``.
    fancy_indexing : bool
        If True, list-like specs are converted to NumPy arrays so that the
        caller can index them with integer index arrays.

    Returns
    -------
    The normalized specification (unchanged constant/None, list, or array).

    Raises
    ------
    TypeError
        If ``spec`` is an array while ``syst`` is a Builder.
    """
    if _p.isarray(spec) and isinstance(syst, builder.Builder):
        raise TypeError('{} cannot be an array when plotting'
                        ' a Builder; use a function instead.'
                        .format(spec_name))
    if callable(spec):
        spec = [spec(i[0]) for i in sites if i[1] is None]
    if (fancy_indexing and _p.isarray(spec)
            and not isinstance(spec, np.ndarray)):
        try:
            spec = np.asarray(spec)
        except (TypeError, ValueError):
            # NumPy could not build a regular array (e.g. ragged entries);
            # fall back to an object array, which still supports fancy
            # indexing.
            spec = np.asarray(spec, dtype='object')
    return spec
962def _make_proper_hop_spec(spec, hops, fancy_indexing=False):
963 if callable(spec):
964 spec = [spec(*i[0]) for i in hops if i[1] is None]
965 if (fancy_indexing and _p.isarray(spec) 965 ↛ 967line 965 didn't jump to line 967, because the condition on line 965 was never true
966 and not isinstance(spec, np.ndarray)):
967 try:
968 spec = np.asarray(spec)
969 except:
970 spec = np.asarray(spec, dtype='object')
971 return spec
def _plotly_segment_trace(end, start, dim):
    """Return a plotly line trace drawing one straight segment per row.

    ``end`` and ``start`` hold the segment end points, one row per segment
    and ``dim`` (2 or 3) columns. The coordinates are interleaved as
    ``end, start, None`` so that plotly breaks the line between segments.
    A ``Scatter3d`` trace is returned for ``dim == 3``, a ``Scatter`` trace
    otherwise.
    """
    nones = [None] * len(end)
    coords = [
        np.array([end[:, axis], start[:, axis], nones]).transpose().flatten()
        for axis in range(dim)
    ]
    if dim == 3:
        x, y, z = coords
        return _p.plotly_graph_objs.Scatter3d(x=x, y=y, z=z)
    x, y = coords
    return _p.plotly_graph_objs.Scatter(x=x, y=y)


def _plot_plotly(sys, num_lead_cells, unit,
                 site_symbol, site_size,
                 site_color, site_edgecolor, site_lw,
                 hop_color, hop_lw,
                 lead_site_symbol, lead_site_size, lead_color,
                 lead_site_edgecolor, lead_site_lw,
                 lead_hop_lw, pos_transform,
                 cmap, colorbar, file,
                 show, fig=None):
    """Plot a system using the plotly engine.

    The arguments have the same meaning as for `plot`. ``colorbar``, ``file``
    and ``show`` are accepted for signature parity with the matplotlib
    backend but are not used here; output is handled by the caller.

    If ``fig`` is None, a new figure holding the system and lead traces is
    created. Otherwise only the lead traces are added to ``fig``.

    Returns
    -------
    fig : plotly ``Figure``

    Raises
    ------
    RuntimeError
        If ``unit`` is not ``'pt'``, or if arrays of edge colors or line
        widths are requested (unsupported by this engine).
    """
    if unit is None:
        unit = 'pt'

    syst = sys  # for naming consistency inside function bodies
    # Generate data.
    sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)
    n_syst_sites = sum(i[1] is None for i in sites)
    sites_pos = sys_leads_pos(syst, sites)
    hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells)
    n_syst_hops = sum(i[1] is None for i in hops)
    end_pos, start_pos = sys_leads_hopping_pos(syst, hops)

    # _check_length looks up both the spec and 'n_syst_sites' by name.
    loc = locals()
    for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor',
                 'site_lw']:
        _check_length(name, loc)

    # Apply transformations to the data
    if pos_transform is not None:
        sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos)
        end_pos = np.apply_along_axis(pos_transform, 1, end_pos)
        start_pos = np.apply_along_axis(pos_transform, 1, start_pos)

    dim = 3 if (sites_pos.shape[1] == 3) else 2

    sites_pos = _resize_to_dim(sites_pos, dim)
    end_pos = _resize_to_dim(end_pos, dim)
    start_pos = _resize_to_dim(start_pos, dim)

    # Determine the reference length.
    if unit != 'pt':
        raise RuntimeError('Plotly engine currently only supports '
                           'the pt symbol size unit')

    site_symbol = _make_proper_site_spec('site_symbol', site_symbol, syst, sites)
    if site_symbol is None:
        site_symbol = defaults['site_symbol'][dim]
    # separate different symbols (not done in 3D, the separation
    # would mess up sorting)
    if (_p.isarray(site_symbol) and dim != 3 and
            (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))):
        symbol_dict = defaultdict(list)
        for i, symbol in enumerate(site_symbol):
            symbol_dict[symbol].append(i)
        symbol_slcs = [(symbol, np.array(indx))
                       for symbol, indx in symbol_dict.items()]
        fancy_indexing = True
    else:
        symbol_slcs = [(site_symbol, slice(n_syst_sites))]
        fancy_indexing = False

    if site_color is None:
        cycle = _color_cycle()
        # Same system test as the matplotlib backend, so that every kind of
        # finalized system gets per-family colors.
        if builder.is_system(syst):
            # Skipping the leads for brevity.
            families = sorted({site.family for site in syst.sites})
            color_mapping = dict(zip(families, cycle))
            def site_color(site):
                return color_mapping[syst.sites[site].family]
        elif isinstance(syst, builder.Builder):
            families = sorted({site[0].family for site in sites})
            color_mapping = dict(zip(families, cycle))
            def site_color(site):
                return color_mapping[site.family]
        else:
            # Unknown finalized system, no sites access.
            site_color = defaults['site_color'][dim]

    site_size = _make_proper_site_spec('site_size', site_size, syst, sites,
                                       fancy_indexing)
    site_color = _make_proper_site_spec('site_color', site_color, syst, sites,
                                        fancy_indexing)
    site_edgecolor = _make_proper_site_spec('site_edgecolor', site_edgecolor,
                                            syst, sites, fancy_indexing)
    site_lw = _make_proper_site_spec('site_lw', site_lw, syst, sites,
                                     fancy_indexing)

    hop_color = _make_proper_hop_spec(hop_color, hops)
    hop_lw = _make_proper_hop_spec(hop_lw, hops)

    # Choose defaults depending on dimension, if None was given
    if site_size is None:
        site_size = defaults['site_size'][dim]
    if site_edgecolor is None:
        site_edgecolor = defaults['site_edgecolor'][dim]
    if site_lw is None:
        site_lw = defaults['site_lw'][dim]

    if hop_color is None:
        hop_color = defaults['hop_color'][dim]
    if hop_lw is None:
        hop_lw = defaults['hop_lw'][dim]

    if len(symbol_slcs) > 1:
        try:
            if site_color.ndim == 1 and len(site_color) == n_syst_sites:
                site_color = np.asarray(site_color, dtype=float)
        except (AttributeError, TypeError, ValueError):
            # Best effort: non-array or non-numeric colors stay as they are.
            pass

    # take spec also for lead, if it's not a list/array, default, otherwise
    if lead_site_symbol is None:
        lead_site_symbol = (site_symbol if not _p.isarray(site_symbol)
                            else defaults['site_symbol'][dim])
    if lead_site_size is None:
        lead_site_size = (site_size if not _p.isarray(site_size)
                          else defaults['site_size'][dim])
    if lead_color is None:
        lead_color = defaults['lead_color'][dim]
    lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color)

    if lead_site_edgecolor is None:
        lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor)
                               else defaults['site_edgecolor'][dim])
    if lead_site_lw is None:
        lead_site_lw = (site_lw if not _p.isarray(site_lw)
                        else defaults['site_lw'][dim])
    if lead_hop_lw is None:
        lead_hop_lw = (hop_lw if not _p.isarray(hop_lw)
                       else defaults['hop_lw'][dim])

    if not isinstance(cmap, str):
        try:
            # A pair (site cmap, hopping cmap); the hopping colormap is not
            # used by the plotly engine.
            cmap, _ = cmap
        except TypeError:
            pass

    # Plot system sites and hoppings

    # First plot the nodes (sites) of the graph
    assert dim == 2 or dim == 3
    site_node_trace, site_edge_trace = [], []
    for symbol, slc in symbol_slcs:
        site_symbol_plotly = _p.convert_symbol_mpl_plotly(symbol)
        if site_symbol_plotly == -1:
            # The kwant documentation supports no symbol as a string argument for site_symbol
            # If it evaluates to -1, then the user has specified "no symbol" as the input.
            # https://kwant-project.org/doc/1/reference/generated/kwant.plotter.plot
            continue
        size = site_size[slc] if _p.isarray(site_size) else site_size
        col = site_color[slc] if _p.isarray(site_color) else site_color
        if _p.isarray(site_edgecolor) or _p.isarray(site_lw):
            raise RuntimeError("Plotly engine does not currently support an array "
                               "of linecolors or linewidths. Please restrict "
                               "to only a constant (i.e. no function or array)"
                               " site_edgecolor and site_lw property "
                               "for the entire plot.")
        edgecol = (site_edgecolor if not isinstance(site_edgecolor, tuple)
                   else _p.convert_colormap_mpl_plotly(*site_edgecolor))

        if dim == 3:
            x, y, z = sites_pos[slc].transpose()
            trace = _p.plotly_graph_objs.Scatter3d(x=x, y=y, z=z)
            trace.marker.symbol = _p.convert_symbol_mpl_plotly_3d(symbol)
        else:
            x, y = sites_pos[slc].transpose()
            trace = _p.plotly_graph_objs.Scatter(x=x, y=y)
            trace.marker.symbol = site_symbol_plotly

        trace.mode = 'markers'
        trace.hoverinfo = 'none'
        trace.marker.showscale = False
        trace.marker.colorscale = _p.convert_cmap_list_mpl_plotly(cmap)
        trace.marker.reversescale = False
        trace.marker.color = (col if not isinstance(col, tuple)
                              else _p.convert_colormap_mpl_plotly(*col))
        trace.marker.size = _p.convert_site_size_mpl_plotly(
            size, defaults['plotly_site_size_reference'])
        # The symbol edges are the marker outline; the trace-level ``line``
        # is not drawn in 'markers' mode.
        trace.marker.line.width = site_lw
        trace.marker.line.color = edgecol
        trace.showlegend = False

        site_node_trace.append(trace)

    # Now plot the edges (hops) of the graph
    if _p.isarray(hop_color) or _p.isarray(hop_lw):
        raise RuntimeError("Plotly engine does not currently support an array "
                           "of linecolors or linewidths. Please restrict "
                           "to only a constant (i.e. no function or array)"
                           " hop_color and hop_lw property "
                           "for the entire plot.")
    trace = _plotly_segment_trace(end_pos[:n_syst_hops],
                                  start_pos[:n_syst_hops], dim)
    trace.line.width = hop_lw
    trace.line.color = hop_color
    trace.hoverinfo = 'none'
    trace.showlegend = False
    trace.mode = 'lines'
    site_edge_trace.append(trace)

    # Plot lead sites and edges
    lead_node_trace, lead_edge_trace = [], []
    for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs):
        # "no symbol" (-1, see the system-site loop above) only suppresses
        # the lead sites; the lead hoppings below are still drawn.
        hide_lead_sites = _p.convert_symbol_mpl_plotly(lead_site_symbol) == -1
        if not hide_lead_sites:
            lead_site_colors = np.array([i[2] for i in sites[sites_slc]],
                                        dtype=float)
            if dim == 3:
                x, y, z = sites_pos[sites_slc].transpose()
                trace = _p.plotly_graph_objs.Scatter3d(x=x, y=y, z=z)
                trace.marker.symbol = \
                    _p.convert_symbol_mpl_plotly_3d(lead_site_symbol)
            else:
                x, y = sites_pos[sites_slc].transpose()
                trace = _p.plotly_graph_objs.Scatter(x=x, y=y)
                trace.marker.symbol = \
                    _p.convert_symbol_mpl_plotly(lead_site_symbol)

            trace.mode = 'markers'
            trace.hoverinfo = 'none'
            trace.showlegend = False
            trace.marker.showscale = False
            trace.marker.reversescale = False
            trace.marker.color = lead_site_colors
            # Fade from the lead color to white, keeping its alpha.
            trace.marker.colorscale = _p.convert_lead_cmap_mpl_plotly(
                lead_color, [1, 1, 1, lead_color[3]])
            trace.marker.size = _p.convert_site_size_mpl_plotly(
                lead_site_size, defaults['plotly_site_size_reference'])

            if _p.isarray(lead_site_lw) or _p.isarray(lead_site_edgecolor):
                raise RuntimeError("Plotly engine does not currently support an array "
                                   "of linecolors or linewidths. Please restrict "
                                   "to only a constant (i.e. no function or array) "
                                   "lead_site_lw and lead_site_edgecolor property "
                                   "for the entire plot.")
            trace.marker.line.width = lead_site_lw
            trace.marker.line.color = lead_site_edgecolor
            lead_node_trace.append(trace)

        trace = _plotly_segment_trace(end_pos[hops_slc], start_pos[hops_slc],
                                      dim)
        trace.line.width = lead_hop_lw
        trace.line.color = _p.convert_colormap_mpl_plotly(*lead_color)
        trace.hoverinfo = 'none'
        trace.mode = 'lines'
        trace.showlegend = False
        lead_edge_trace.append(trace)

    layout = _p.plotly_graph_objs.Layout(
        showlegend=False,
        hovermode='closest',
        xaxis=dict(showgrid=False, zeroline=False,
                   showticklabels=True),
        yaxis=dict(showgrid=False, zeroline=False,
                   showticklabels=True))
    if fig is None:
        full_trace = list(itertools.chain(site_edge_trace, site_node_trace,
                                          lead_edge_trace, lead_node_trace))
        fig = _p.plotly_graph_objs.Figure(data=full_trace, layout=layout)
    else:
        for trace in itertools.chain(lead_edge_trace, lead_node_trace):
            try:
                fig.add_trace(trace)
            except TypeError:
                fig.data += [trace]

    return fig
def _plot_matplotlib(sys, num_lead_cells, unit,
                     site_symbol, site_size,
                     site_color, site_edgecolor, site_lw,
                     hop_color, hop_lw,
                     lead_site_symbol, lead_site_size, lead_color,
                     lead_site_edgecolor, lead_site_lw,
                     lead_hop_lw, pos_transform,
                     cmap, colorbar, file,
                     show, dpi, fig_size, ax):
    """Plot a system using the matplotlib engine.

    The arguments have the same meaning as for `plot`. ``file`` only decides
    whether the new figure is created through pyplot. ``show`` is not used
    here; output is handled by the caller.

    Returns
    -------
    fig : matplotlib figure or None
        The newly created figure, or None if ``ax`` was given.

    Raises
    ------
    RuntimeError
        If a 3D plot is required but the installed matplotlib lacks 3D
        support.
    ValueError
        If ``unit`` is neither 'nn', 'pt' nor convertible to float.
    """
    if unit is None:
        unit = 'nn'

    syst = sys  # for naming consistency inside function bodies
    # Generate data.
    sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)
    n_syst_sites = sum(i[1] is None for i in sites)
    sites_pos = sys_leads_pos(syst, sites)
    hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells)
    n_syst_hops = sum(i[1] is None for i in hops)
    end_pos, start_pos = sys_leads_hopping_pos(syst, hops)

    # _check_length looks up both the spec and 'n_syst_sites' by name.
    loc = locals()
    for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor',
                 'site_lw']:
        _check_length(name, loc)

    # Apply transformations to the data
    if pos_transform is not None:
        sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos)
        end_pos = np.apply_along_axis(pos_transform, 1, end_pos)
        start_pos = np.apply_along_axis(pos_transform, 1, start_pos)

    dim = 3 if (sites_pos.shape[1] == 3) else 2
    if dim == 3 and not _p.has3d:
        raise RuntimeError("Installed matplotlib does not support 3d plotting")
    sites_pos = _resize_to_dim(sites_pos, dim)
    end_pos = _resize_to_dim(end_pos, dim)
    start_pos = _resize_to_dim(start_pos, dim)

    # Determine the reference length; None means sizes are given in points.
    if unit == 'pt':
        reflen = None
    elif unit == 'nn':
        if n_syst_hops:
            # If hoppings are present use their lengths to determine the
            # minimal one.
            distances = end_pos - start_pos
        else:
            # If no hoppings are present, use for the same purpose distances
            # from ten randomly selected points to the remaining points in the
            # system.
            points = _sample_array(sites_pos, 10).T
            distances = (sites_pos.reshape(1, -1, dim) -
                         points.reshape(-1, 1, dim)).reshape(-1, dim)
        distances = np.sort(np.sum(distances**2, axis=1))
        # Then check if distances are present that are way shorter than the
        # longest one. Then take first distance longer than these short
        # ones. This heuristic will fail for too large systems, or systems with
        # hoppings that vary by orders and orders of magnitude, but for sane
        # cases it will work.
        long_dist_coord = np.searchsorted(distances, 1e-16 * distances[-1])
        reflen = sqrt(distances[long_dist_coord])
    else:
        # The last allowed value is float-compatible.
        try:
            reflen = float(unit)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError('Invalid value of unit argument.') from err

    site_symbol = _make_proper_site_spec('site_symbol', site_symbol, syst, sites)
    if site_symbol is None:
        site_symbol = defaults['site_symbol'][dim]
    # separate different symbols (not done in 3D, the separation
    # would mess up sorting)
    if (_p.isarray(site_symbol) and dim != 3 and
            (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))):
        symbol_dict = defaultdict(list)
        for i, symbol in enumerate(site_symbol):
            symbol_dict[symbol].append(i)
        symbol_slcs = [(symbol, np.array(indx))
                       for symbol, indx in symbol_dict.items()]
        fancy_indexing = True
    else:
        symbol_slcs = [(site_symbol, slice(n_syst_sites))]
        fancy_indexing = False

    if site_color is None:
        cycle = _color_cycle()
        if builder.is_system(syst):
            # Skipping the leads for brevity.
            families = sorted({site.family for site in syst.sites})
            color_mapping = dict(zip(families, cycle))
            def site_color(site):
                return color_mapping[syst.sites[site].family]
        elif isinstance(syst, builder.Builder):
            families = sorted({site[0].family for site in sites})
            color_mapping = dict(zip(families, cycle))
            def site_color(site):
                return color_mapping[site.family]
        else:
            # Unknown finalized system, no sites access.
            site_color = defaults['site_color'][dim]

    site_size = _make_proper_site_spec('site_size', site_size, syst, sites,
                                       fancy_indexing)
    site_color = _make_proper_site_spec('site_color', site_color, syst, sites,
                                        fancy_indexing)
    site_edgecolor = _make_proper_site_spec('site_edgecolor', site_edgecolor,
                                            syst, sites, fancy_indexing)
    site_lw = _make_proper_site_spec('site_lw', site_lw, syst, sites,
                                     fancy_indexing)

    hop_color = _make_proper_hop_spec(hop_color, hops)
    hop_lw = _make_proper_hop_spec(hop_lw, hops)

    # Choose defaults depending on dimension, if None was given
    if site_size is None:
        site_size = defaults['site_size'][dim]
    if site_edgecolor is None:
        site_edgecolor = defaults['site_edgecolor'][dim]
    if site_lw is None:
        site_lw = defaults['site_lw'][dim]

    if hop_color is None:
        hop_color = defaults['hop_color'][dim]
    if hop_lw is None:
        hop_lw = defaults['hop_lw'][dim]

    # if symbols are split up into different collections,
    # the colormapping will fail without normalization
    norm = None
    if len(symbol_slcs) > 1:
        try:
            if site_color.ndim == 1 and len(site_color) == n_syst_sites:
                site_color = np.asarray(site_color, dtype=float)
                norm = _p.matplotlib.colors.Normalize(site_color.min(),
                                                      site_color.max())
        except (AttributeError, TypeError, ValueError):
            # Best effort: non-array or non-numeric colors need no norm.
            pass

    # take spec also for lead, if it's not a list/array, default, otherwise
    if lead_site_symbol is None:
        lead_site_symbol = (site_symbol if not _p.isarray(site_symbol)
                            else defaults['site_symbol'][dim])
    if lead_site_size is None:
        lead_site_size = (site_size if not _p.isarray(site_size)
                          else defaults['site_size'][dim])
    if lead_color is None:
        lead_color = defaults['lead_color'][dim]
    lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color)

    if lead_site_edgecolor is None:
        lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor)
                               else defaults['site_edgecolor'][dim])
    if lead_site_lw is None:
        lead_site_lw = (site_lw if not _p.isarray(site_lw)
                        else defaults['site_lw'][dim])
    if lead_hop_lw is None:
        lead_hop_lw = (hop_lw if not _p.isarray(hop_lw)
                       else defaults['hop_lw'][dim])

    hop_cmap = None
    if not isinstance(cmap, str):
        try:
            # A pair of colormaps: (sites, hoppings).
            cmap, hop_cmap = cmap
        except TypeError:
            pass

    # make a new figure unless axes specified
    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        if dim == 2:
            ax = fig.add_subplot(1, 1, 1, aspect='equal')
            ax.set_xmargin(0.05)
            ax.set_ymargin(0.05)
        else:
            # Ignore warnings mentioning 'rotation' while creating the 3D
            # axes, without clobbering the global warning filters (which
            # warnings.resetwarnings() would do).
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', message=r'.*rotation.*')
                ax = fig.add_subplot(1, 1, 1, projection='3d')
    else:
        fig = None

    # plot system sites and hoppings
    for symbol, slc in symbol_slcs:
        size = site_size[slc] if _p.isarray(site_size) else site_size
        col = site_color[slc] if _p.isarray(site_color) else site_color
        edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else
                   site_edgecolor)
        lw = site_lw[slc] if _p.isarray(site_lw) else site_lw

        symbol_coll = symbols(ax, sites_pos[slc], size=size,
                              reflen=reflen, symbol=symbol,
                              facecolor=col, edgecolor=edgecol,
                              linewidth=lw, cmap=cmap, norm=norm, zorder=2)

    end, start = end_pos[:n_syst_hops], start_pos[:n_syst_hops]
    line_coll = lines(ax, end, start, reflen, hop_color, linewidths=hop_lw,
                      zorder=1, cmap=hop_cmap)

    # plot lead sites and hoppings
    norm = _p.matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5)
    cmap_from_list = _p.matplotlib.colors.LinearSegmentedColormap.from_list
    # Fade from the lead color to white, keeping the lead color's alpha.
    lead_cmap = cmap_from_list(None, [lead_color, (1, 1, 1, lead_color[3])])

    for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs):
        lead_site_colors = np.array([i[2] for i in sites[sites_slc]],
                                    dtype=float)

        # Note: the previous version of the code had in addition this
        # line in the 3D case:
        # lead_site_colors = 1 / np.sqrt(1. + lead_site_colors)
        symbols(ax, sites_pos[sites_slc], size=lead_site_size, reflen=reflen,
                symbol=lead_site_symbol, facecolor=lead_site_colors,
                edgecolor=lead_site_edgecolor, linewidth=lead_site_lw,
                cmap=lead_cmap, zorder=2, norm=norm)

        lead_hop_colors = np.array([i[2] for i in hops[hops_slc]], dtype=float)

        # Note: the previous version of the code had in addition this
        # line in the 3D case:
        # lead_hop_colors = 1 / np.sqrt(1. + lead_hop_colors)
        end, start = end_pos[hops_slc], start_pos[hops_slc]
        lines(ax, end, start, reflen, lead_hop_colors, linewidths=lead_hop_lw,
              cmap=lead_cmap, norm=norm, zorder=1)

    min_ = np.min(sites_pos, 0)
    max_ = np.max(sites_pos, 0)
    m = (min_ + max_) / 2
    if dim == 2:
        half_width = (max_ - min_) / 2
        if reflen is not None:
            # Keep at least one reference length around the center. With
            # unit='pt' there is no real-space reference length, and the
            # former np.max([..., (None, None)]) raised a TypeError.
            half_width = np.maximum(half_width, reflen)
        ax.update_datalim((m - half_width, m + half_width))
        ax.autoscale_view(tight=True)
    else:
        # make axis limits the same in all directions
        # (3D only works decently for equal aspect ratio. Since
        # this doesn't work out of the box in mplot3d, this is a
        # workaround)
        w = np.max(max_ - min_) / 2
        ax.auto_scale_xyz(*[(i - w, i + w) for i in m], had_data=True)

    # add separate colorbars for symbols and hoppings if necessary
    if symbol_coll.get_array() is not None and colorbar and fig is not None:
        fig.colorbar(symbol_coll)
    if line_coll.get_array() is not None and colorbar and fig is not None:
        fig.colorbar(line_coll)

    return fig
def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
    """Interpolate a scalar function in vicinity of given points.

    Create a masked array corresponding to interpolated values of the function
    at points lying not further than a certain distance from the original
    data points provided.

    Parameters
    ----------
    coords : np.ndarray
        An array with site coordinates.
    values : np.ndarray
        An array with the values from which the interpolation should be built.
    a : float, optional
        Reference length.  If not given, it is determined as a typical
        nearest neighbor distance.
    method : string, optional
        Passed to ``scipy.interpolate.griddata``: "nearest" (default), "linear",
        or "cubic"
    oversampling : integer, optional
        Number of pixels per reference length.  Defaults to 3.

    Returns
    -------
    array : 2d NumPy array
        The interpolated values with pixels further than ``0.99 * a`` from
        every site masked out.  With the matplotlib engine this is a
        ``numpy.ma.MaskedArray``; with any other engine the masked pixels are
        filled with NaN and a plain array is returned.
    unmasked : 2d NumPy array
        The interpolated values without any masking applied.
    min, max : vectors
        The real-space coordinates of the two extreme ([0, 0] and [-1, -1])
        points of ``array``.

    Raises
    ------
    ValueError
        If the reference length ``a`` is too small compared to the size of
        the system, or if ``coords`` and ``values`` differ in length.

    Notes
    -----
    - `min` and `max` are chosen such that when plotting a system on a square
      lattice and `oversampling` is set to an odd integer, each site will lie
      exactly at the center of a pixel of the output array.

    - When plotting a system on a square lattice and `method` is "nearest", it
      makes sense to set `oversampling` to ``1``.  Then, each site will
      correspond to exactly one pixel in the resulting array.
    """
    # Build the bounding box.
    cmin, cmax = coords.min(0), coords.max(0)

    tree = spatial.cKDTree(coords)

    # Select 10 sites to compare -- comparing them all is too costly.
    points = _sample_array(coords, 10)
    min_dist = np.min(tree.query(points, 2)[0][:, 1])
    if min_dist < 1e-6 * np.linalg.norm(cmax - cmin):
        warnings.warn("Some sites have nearly coinciding positions, "
                      "interpolation may be confusing.",
                      RuntimeWarning, stacklevel=2)

    if a is None:
        a = min_dist

    if a < 1e-6 * np.linalg.norm(cmax - cmin):
        raise ValueError("The reference distance a is too small.")

    if len(coords) != len(values):
        raise ValueError("The number of sites doesn't match the number of "
                         "provided values.")

    shape = (((cmax - cmin) / a + 1) * oversampling).round()
    delta = 0.5 * (oversampling - 1) * a / oversampling
    # Out-of-place so that integer-valued coordinates are promoted to float
    # instead of raising an in-place casting error.
    cmin = cmin - delta
    cmax = cmax + delta
    dims = tuple(slice(cmin[i], cmax[i], 1j * shape[i]) for i in
                 range(len(cmin)))
    grid = tuple(np.ogrid[dims])
    img = interpolate.griddata(coords, values, grid, method)
    # ``np.float_`` no longer exists in NumPy 2; ``np.float64`` is the same type.
    img = img.astype(np.float64)
    mask = np.mgrid[dims].reshape(len(cmin), -1).T
    # The numerical values in the following line are optimized for the common
    # case of a square lattice:
    # * 0.99 makes sure that non-masked pixels and sites correspond 1-by-1 to
    #   each other when oversampling == 1.
    # * 0.4 (which is just below sqrt(2) - 1) makes tree.query() exact.
    mask = tree.query(mask, eps=0.4)[0] > 0.99 * a

    masked_result_array = np.ma.masked_array(img, mask)

    # Non-matplotlib engines cannot consume masked arrays, so masked pixels
    # become NaN there.  The AttributeError fallback keeps the masked array
    # when the plotting backend exposes no ``engine`` attribute.
    try:
        if _p.engine != "matplotlib":
            result_array = masked_result_array.filled(np.nan)
        else:
            result_array = masked_result_array
    except AttributeError:
        result_array = masked_result_array

    return result_array, img, cmin, cmax
def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
        method='nearest', oversampling=3, num_lead_cells=0, file=None,
        show=True, dpi=None, fig_size=None, ax=None, pos_transform=None,
        background='#e0e0e0'):
    """Show interpolated map of a function defined for the sites of a system.

    Create a pixmap representation of a function of the sites of a system by
    calling `~kwant.plotter.mask_interpolate` and show this pixmap using
    matplotlib (or plotly, if that engine is selected).

    This function is similar to `~kwant.plotter.density`, but is more suited
    to the case where you want site-level resolution of the quantity that
    you are plotting.  If your system has many sites you may get more
    appealing plots by using `~kwant.plotter.density`.

    Parameters
    ----------
    sys : kwant.system.FiniteSystem or kwant.builder.Builder
        The system for whose sites `value` is to be plotted.
    value : function or list
        Function which takes a site and returns a value if the system is a
        builder, or a list of function values for each system site of the
        finalized system.
    colorbar : bool, optional
        Whether to show a color bar if numerical data has to be plotted.
        Defaults to `True`.  If `ax` is provided, the colorbar is never plotted.
    cmap : ``matplotlib`` color map or `None`
        The color map used for the interpolated values.  If `None`, the Kwant
        default red color map is used.
    vmin : float, optional
        The lower saturation limit for the colormap; values returned by
        `value` which are smaller than this will saturate.
    vmax : float, optional
        The upper saturation limit for the colormap; values returned by
        `value` which are larger than this will saturate.
    a : float, optional
        Reference length.  If not given, it is determined as a typical
        nearest neighbor distance.
    method : string, optional
        Passed to ``scipy.interpolate.griddata``: "nearest" (default), "linear",
        or "cubic"
    oversampling : integer, optional
        Number of pixels per reference length.  Defaults to 3.
    num_lead_cells : integer, optional
        Number of lead unit cells that should be plotted to indicate
        the position of leads.  Defaults to 0.
    file : string or file object or `None`
        The output file.  If `None`, output will be shown instead.
    show : bool
        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
        to be shown immediately.  Defaults to `True`.
    dpi : float
        Number of pixels per inch.  If not set the ``matplotlib`` default is
        used.  Only used by the matplotlib engine.
    fig_size : tuple
        Figure size `(width, height)` in inches.  If not set, the default
        ``matplotlib`` value is used.  Only used by the matplotlib engine.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`.  In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
    pos_transform : function or `None`
        Transformation to be applied to the site position.
    background : matplotlib color spec
        Areas without sites are filled with this color.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    RuntimeError
        If no plotting library is available or the engine is unsupported.
    ValueError
        If the system is not two-dimensional, or if a list of values is
        given for a system that is not finalized.

    Notes
    -----
    - When plotting a system on a square lattice and `method` is "nearest", it
      makes sense to set `oversampling` to ``1``.  Then, each site will
      correspond to exactly one pixel.

    See Also
    --------
    kwant.plotter.density
    """

    if not (_p.mpl_available or _p.plotly_available):
        raise RuntimeError("matplotlib was not found, but is required "
                           "for map()")

    syst = sys  # for naming consistency inside function bodies
    sites = sys_leads_sites(syst, 0)[0]
    coords = sys_leads_pos(syst, sites)

    if pos_transform is not None:
        coords = np.apply_along_axis(pos_transform, 1, coords)

    if coords.shape[1] != 2:
        raise ValueError('Only 2D systems can be plotted this way.')

    if callable(value):
        value = [value(site[0]) for site in sites]
    else:
        if not system.is_finite(syst):
            raise ValueError('List of values is only allowed as input '
                             'for finalized systems.')
    value = np.array(value)
    with _common.reraise_warnings():
        img, unmasked_data, _min, _max = mask_interpolate(coords, value,
                                                          a, method,
                                                          oversampling)

    # Calculate the min/max bounds for the colormap.
    # User-provided values take precedence.
    if _p.engine != "matplotlib":
        unmasked_data = img.ravel()
    else:
        unmasked_data = img[~img.mask].data.flatten()
    unmasked_data = unmasked_data[~np.isnan(unmasked_data)]
    new_vmin, new_vmax = percentile_bound(unmasked_data, vmin, vmax)
    overflow_pct = 100 * np.sum(unmasked_data > new_vmax) / len(unmasked_data)
    underflow_pct = 100 * np.sum(unmasked_data < new_vmin) / len(unmasked_data)
    # Only warn about saturation the user did not ask for explicitly.
    if (vmin is None and underflow_pct) or (vmax is None and overflow_pct):
        msg = (
            'The plotted data contains ',
            '{:.2f}% of values overflowing upper limit {:g} '
            .format(overflow_pct, new_vmax)
            if overflow_pct > 0 else '',
            'and ' if overflow_pct > 0 and underflow_pct > 0 else '',
            '{:.2f}% of values underflowing lower limit {:g} '
            .format(underflow_pct, new_vmin)
            if underflow_pct > 0 else '',
        )
        warnings.warn(''.join(msg), RuntimeWarning, stacklevel=2)
    vmin, vmax = new_vmin, new_vmax

    if _p.engine == "matplotlib":
        fig = _map_matplotlib(syst, img, colorbar, _max, _min, vmin, vmax,
                              overflow_pct, underflow_pct, cmap, num_lead_cells,
                              background, dpi, fig_size, ax, file)
    elif _p.engine == "plotly":
        fig = _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax,
                          overflow_pct, underflow_pct, cmap, num_lead_cells,
                          background)
    elif _p.engine is None:
        raise RuntimeError("Cannot use map() without a plotting lib installed")
    else:
        raise RuntimeError("map() does not support engine '{}'".format(_p.engine))

    _maybe_output_fig(fig, file=file, show=show)

    return fig
def _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax, overflow_pct,
                underflow_pct, cmap, num_lead_cells, background):
    """Draw the interpolated map ``img`` as a plotly heatmap.

    ``img[i, j]`` holds the value at grid point ``(x_i, y_j)``.  ``_min`` and
    ``_max`` are the coordinates of the first and last grid points, as
    returned by ``mask_interpolate``; neither array is modified.
    ``overflow_pct`` and ``underflow_pct`` are accepted for signature parity
    with ``_map_matplotlib`` and are not used here.

    Returns the plotly ``Figure``.
    """
    if cmap is None:
        cmap = _p.kwant_red_plotly

    # Plotly indexes ``z[row][column]`` with rows along y and columns along x,
    # so the (x, y)-indexed image is transposed.
    z = img.T
    n_y, n_x = z.shape

    heatmap = _p.plotly_graph_objs.Heatmap()
    heatmap.z = z
    # One coordinate per brick, centred on the grid points: ``n_x`` columns
    # along x and ``n_y`` rows along y (these differ for non-square maps).
    heatmap.x = np.linspace(_min[0], _max[0], n_x)
    heatmap.y = np.linspace(_min[1], _max[1], n_y)
    heatmap.zsmooth = False
    heatmap.connectgaps = False
    heatmap.colorscale = _p.convert_cmap_list_mpl_plotly(cmap)
    heatmap.zmax = vmax
    heatmap.zmin = vmin
    heatmap.hoverinfo = 'none'

    heatmap.showscale = colorbar

    fig = _p.plotly_graph_objs.Figure(data=[heatmap])
    fig.layout.plot_bgcolor = background
    fig.layout.showlegend = False

    if num_lead_cells:
        # Overlay the lead unit cells as black squares on top of the map.
        fig = _plot_plotly(syst, num_lead_cells, site_symbol='no symbol',
                           hop_lw=0, lead_site_symbol='s',
                           lead_site_size=0.501, lead_site_lw=0, lead_hop_lw=0,
                           lead_color='black', colorbar=False, show=False,
                           fig=fig, unit='pt', site_size=None, site_color=None,
                           site_edgecolor=None, site_lw=0, hop_color=None,
                           lead_site_edgecolor=None, pos_transform=None,
                           cmap=None, file=None)

    return fig
def _map_matplotlib(syst, img, colorbar, _max, _min, vmin, vmax,
                    overflow_pct, underflow_pct, cmap, num_lead_cells,
                    background, dpi, fig_size, ax, file):
    """Draw the interpolated map ``img`` with ``imshow`` on matplotlib axes.

    ``_min`` and ``_max`` are the coordinates of the first and last grid
    points (pixel centres) as returned by ``mask_interpolate``; they are not
    modified.  ``overflow_pct``/``underflow_pct`` select pointed colorbar
    ends when the colormap saturates.

    Returns the new figure, or None if ``ax`` was supplied.
    """
    # imshow's ``extent`` refers to the outer pixel edges, so widen the
    # centre-to-centre range by half a pixel on each side.  New arrays are
    # built so that the caller's ``_min``/``_max`` stay untouched.
    half_pixel = 0.5 * (_max - _min) / (np.asarray(img.shape) - 1)
    lower = _min - half_pixel
    upper = _max + half_pixel
    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        ax = fig.add_subplot(1, 1, 1, aspect='equal')
    else:
        fig = None

    if cmap is None:
        cmap = _p.kwant_red_matplotlib

    # Note that we tell imshow to show the array created by mask_interpolate
    # faithfully and not to interpolate by itself another time.
    image = ax.imshow(img.T, extent=(lower[0], upper[0], lower[1], upper[1]),
                      origin='lower', interpolation='none', cmap=cmap,
                      vmin=vmin, vmax=vmax)
    if num_lead_cells:
        plot(syst, num_lead_cells, site_symbol='no symbol', hop_lw=0,
             lead_site_symbol='s', lead_site_size=0.501, lead_site_lw=0,
             lead_hop_lw=0, lead_color='black', colorbar=False, ax=ax)

    ax.patch.set_facecolor(background)

    if colorbar and fig is not None:
        # Make the colorbar ends pointy if we saturate the colormap
        extend = 'neither'
        if underflow_pct > 0 and overflow_pct > 0:
            extend = 'both'
        elif underflow_pct > 0:
            extend = 'min'
        elif overflow_pct > 0:
            extend = 'max'
        fig.colorbar(image, extend=extend)

    return fig
@deprecate_args
def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
          fig_size=None, ax=None, *, params=None):
    """Plot band structure of a translationally invariant 1D system.

    The Bloch Hamiltonian ``H(k) = H_0 + V exp(-ik) + V^dagger exp(ik)`` is
    assembled from the unit-cell Hamiltonian ``H_0`` and the inter-cell
    hopping ``V`` of ``sys`` and passed on to `~kwant.plotter.spectrum`,
    which diagonalizes and plots it.

    Parameters
    ----------
    sys : kwant.system.InfiniteSystem
        The system whose bands are plotted.
    args : tuple, defaults to empty
        Positional arguments to pass to the ``hamiltonian`` method.
        Deprecated in favor of 'params' (and mutually exclusive with it).
    momenta : int or 1D array-like
        Either a number of sampling points on the interval [-pi, pi], or an
        array of points at which the band structure has to be evaluated.
    file : string or file object or `None`
        The output file.  If `None`, output will be shown instead.  If plotly
        is selected as the engine, the filename has to end with a html
        extension.
    show : bool
        For matplotlib engine, whether ``matplotlib.pyplot.show()`` is to be
        called, and the output is to be shown immediately.
        For the plotly engine, a call to ``iplot(fig)`` is made if
        show is True.
        Defaults to `True` for both engines.
    dpi : float
        Number of pixels per inch.  If not set the ``matplotlib`` default is
        used.
        Only for matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    fig_size : tuple
        Figure size `(width, height)` in inches.  If not set, the default
        ``matplotlib`` value is used.
        Only for matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`.  In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
        Only for matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    params : dict, optional
        Dictionary of parameter names and their values.  Mutually exclusive
        with 'args'.

    Returns
    -------
    fig : matplotlib figure or plotly Figure object
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    ValueError
        If the cell Hamiltonian is not Hermitian.

    Notes
    -----
    See `~kwant.physics.Bands` for the calculation of dispersion without
    plotting.
    """
    syst = sys  # for naming consistency inside function bodies

    if _p.plotly_available and _p.engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)

    _common.ensure_isinstance(
        syst, (system.InfiniteSystem, system.InfiniteVectorizedSystem))

    momenta = np.array(momenta)
    if momenta.ndim != 1:
        # A scalar is the number of sampling points over the Brillouin zone.
        momenta = np.linspace(-np.pi, np.pi, momenta)

    # The contents of 'physics.Bands' are expanded here to obtain H(k)
    # directly, because 'spectrum' already takes care of diagonalization.
    cell_ham = syst.cell_hamiltonian(args, params=params)
    if not np.allclose(cell_ham, cell_ham.conjugate().transpose()):
        raise ValueError('The cell Hamiltonian is not Hermitian.')
    inter_hop = syst.inter_cell_hopping(args, params=params)
    # Zero-pad the (possibly narrower) inter-cell hopping to the cell size.
    hopping = np.zeros(cell_ham.shape, dtype=complex)
    hopping[:, :inter_hop.shape[1]] = inter_hop

    def h_k(k):
        # H(k) = H_0 + V exp(-ik) + V^dagger exp(ik)
        forward = hopping * cmath.exp(-1j * k)
        return forward + (forward.conjugate().transpose() + cell_ham)

    return spectrum(h_k, ('k', momenta), file=file, show=show, dpi=dpi,
                    fig_size=fig_size, ax=ax)
def spectrum(syst, x, y=None, params=None, mask=None, file=None,
             show=True, dpi=None, fig_size=None, ax=None):
    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters.

    Either matplotlib (the default engine) or plotly must be installed.

    Parameters
    ----------
    syst : `kwant.system.FiniteSystem` or callable
        The system whose Hamiltonian is diagonalized.  A callable must accept
        the parameters as keyword arguments and return the Hamiltonian as a
        dense matrix.
    x : pair ``(name, values)``
        The parameter of the Hamiltonian that is varied: its name and the
        sequence of values it takes.
    y : pair ``(name, values)``, optional
        A second varied parameter, given like ``x``, for 3D plots.  The
        spectrum is then evaluated on the cartesian product of the ``x``
        values and these values.
    params : dict, optional
        The remaining parameters of the Hamiltonian, held fixed.
    mask : callable, optional
        Called with the ``x`` (and ``y``) parameters as keyword arguments;
        where it returns True the spectrum is not calculated.
    file : string or file object or `None`
        The output file.  If `None`, output will be shown instead.  If plotly
        is selected as the engine, the filename has to end with a html
        extension.
    show : bool
        For matplotlib engine, whether ``matplotlib.pyplot.show()`` is to be
        called, and the output is to be shown immediately.
        For the plotly engine, a call to ``iplot(fig)`` is made if
        show is True.
        Defaults to `True` for both engines.
    dpi : float
        Number of pixels per inch.  If not set the ``matplotlib`` default is
        used.
        Only for matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    fig_size : tuple
        Figure size `(width, height)` in inches.  If not set, the default
        ``matplotlib`` value is used.
        Only for matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`.  In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
        Only for matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.

    Returns
    -------
    fig : matplotlib figure or plotly Figure object
        With the matplotlib engine, `None` if `ax` was given.
    """
    params = params or dict()

    engine = _p.engine
    if engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)
        return _spectrum_plotly(syst, x, y, params, mask, file, show)
    if engine == "matplotlib":
        return _spectrum_matplotlib(syst, x, y, params, mask, file,
                                    show, dpi, fig_size, ax)
    if engine is None:
        raise RuntimeError("Cannot use spectrum() without a plotting lib installed")
    raise RuntimeError("spectrum() does not support engine '{}'".format(engine))
def _generate_spectrum(syst, params, mask, x, y):
    """Generates the spectrum dataset for the internal plotting
    functions of spectrum().

    Parameters
    ----------
    See spectrum(...) documentation.  ``params`` may be None, meaning no
    fixed parameters.

    Returns
    -------
    spectrum : Numpy array
        The energies of the system calculated at each coordinate, with shape
        ``(len(x values), [len(y values),] n_eigenvalues)``.  Masked points
        are filled with NaN.
    planar : bool
        True if y is None
    array_values : tuple
        The coordinates of x, y values of the dataset for plotting.
    keys : tuple
        Labels for the x and y axes.

    Raises
    ------
    TypeError
        If ``syst`` is neither a finite Kwant system nor a callable.
    ValueError
        If no point of the grid was evaluated (no parameter values were
        given, or ``mask`` excludes every point).
    """
    if system.is_finite(syst):
        def ham(**kwargs):
            return syst.hamiltonian_submatrix(params=kwargs, sparse=False)
    elif callable(syst):
        ham = syst
    else:
        raise TypeError("Expected 'syst' to be a finite Kwant system "
                        "or a function.")

    if params is None:
        params = {}

    planar = y is None
    keys = (x[0],) if planar else (x[0], y[0])
    array_values = (x[1],) if planar else (x[1], y[1])

    # calculate spectrum on the grid of points
    spectrum = []
    bound_ham = functools.partial(ham, **params)
    for point in itertools.product(*array_values):
        p = dict(zip(keys, point))
        if mask and mask(**p):
            spectrum.append(None)
        else:
            h_p = np.atleast_2d(bound_ham(**p))
            spectrum.append(np.linalg.eigvalsh(h_p))

    # The length of any computed spectrum tells how many NaNs a masked point
    # needs so that the result forms a regular array.
    first_computed = next((s for s in spectrum if s is not None), None)
    if first_computed is None:
        raise ValueError("No spectrum was computed: the parameter values are "
                         "empty or 'mask' excludes every point.")
    nan_list = [np.nan] * len(first_computed)
    spectrum = [nan_list if s is None else s for s in spectrum]
    # make into a numpy array and reshape
    new_shape = [len(v) for v in array_values] + [-1]
    spectrum = np.array(spectrum).reshape(new_shape)

    return spectrum, planar, array_values, keys
def _spectrum_plotly(syst, x, y=None, params=None, mask=None,
                     file=None, show=True):
    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
    using the plotly engine.

    Parameters
    ----------
    See spectrum(...) documentation.  ``params`` may be None.

    Returns
    -------
    fig : plotly Figure / dict
    """
    if params is None:
        params = {}

    spectrum, planar, array_values, keys = _generate_spectrum(syst, params,
                                                              mask, x, y)

    if planar:
        fig = _p.plotly_graph_objs.Figure(data=[
            _p.plotly_graph_objs.Scatter(
                x=array_values[0],
                y=energies,
            ) for energies in spectrum.T
        ])
        fig.layout.xaxis.title = keys[0]
        fig.layout.yaxis.title = 'Energy'
        fig.layout.showlegend = False
    else:
        # Masked points are NaN; ignore them when fixing the shared colour
        # range so that it stays finite.
        energy_max = np.nanmax(spectrum)
        energy_min = np.nanmin(spectrum)
        fig = _p.plotly_graph_objs.Figure(data=[
            _p.plotly_graph_objs.Surface(
                x=array_values[0],
                y=array_values[1],
                z=energies,
                cmax=energy_max,
                cmin=energy_min,
            ) for energies in spectrum.T
        ])
        fig.layout.scene.xaxis.title = keys[0]
        fig.layout.scene.yaxis.title = keys[1]
        fig.layout.scene.zaxis.title = 'Energy'

    # Same title format as the matplotlib engine: callables are skipped.
    fig.layout.title = ', '.join(
        '{} = {}'.format(key, value)
        for key, value in params.items()
        if not callable(value)
    )

    _maybe_output_fig(fig, file=file, show=show)

    return fig
def _spectrum_matplotlib(syst, x, y=None, params=None, mask=None, file=None,
                         show=True, dpi=None, fig_size=None, ax=None):
    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
    using the matplotlib engine.

    Parameters
    ----------
    See spectrum(...) documentation.  ``params`` may be None.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    RuntimeError
        If a 2D parameter domain is requested but matplotlib lacks 3D support.
    TypeError
        If a 2D parameter domain is requested with an ``ax`` that is not a
        3D axis.
    """
    if params is None:
        params = {}

    if y is not None and not _p.has3d:
        raise RuntimeError("Installed matplotlib does not support 3d plotting")

    spectrum, planar, array_values, keys = _generate_spectrum(syst, params,
                                                              mask, x, y)

    # set up axes
    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        if planar:
            ax = fig.add_subplot(1, 1, 1)
        else:
            # Silence the mplot3d rotation warning for this call only; unlike
            # ``warnings.resetwarnings()`` this restores the caller's filters.
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore',
                                        message=r'.*mouse rotation disabled.*')
                ax = fig.add_subplot(1, 1, 1, projection='3d')
        ax.set_xlabel(keys[0])
        if planar:
            ax.set_ylabel('Energy')
        else:
            ax.set_ylabel(keys[1])
            ax.set_zlabel('Energy')
        ax.set_title(
            ', '.join(
                '{} = {}'.format(key, value)
                for key, value in params.items()
                if not callable(value)
            )
        )
    else:
        fig = None

    # actually do the plot
    if planar:
        ax.plot(array_values[0], spectrum)
    else:
        if not hasattr(ax, 'plot_surface'):
            msg = ("When providing an axis for plotting over a 2D domain the "
                   "axis should be created with 'projection=\"3d\"")
            raise TypeError(msg)
        # plot_surface cannot directly handle rank-3 values, so we
        # explicitly loop over the last axis
        grid = np.meshgrid(*array_values)
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message='Z contains NaN values')
            for i in range(spectrum.shape[-1]):
                spec = spectrum[:, :, i].transpose()  # row-major to x-y ordering
                # Unpack rather than ``grid + [spec]``: meshgrid returns a
                # tuple on newer NumPy, a list on older ones.
                ax.plot_surface(*grid, spec, cstride=1, rstride=1)

    _maybe_output_fig(fig, file=file, show=show)

    return fig
# Smoothing functions used with 'interpolate_current'.

# Convolution kernel with finite support:
# f(r) = (1-r^2)^2 Θ(1-r^2)
def _bump(r):
    """Evaluate the bump ``(1 - r**2)**2`` for ``r <= 1``; 0 for ``r > 1``.

    ``r`` may be a scalar or an array of (non-negative) distances.  The input
    is not modified.
    """
    clipped = np.minimum(r, 1)
    m = 1 - clipped * clipped
    return m * m
# We generate the smoothing function by convolving the current
# defined on a line between the two sites with
# f(ρ, z) = (1 - ρ^2 - z^2)^2 Θ(1 - ρ^2 - z^2), where ρ and z are
# cylindrical coords defined with respect to the hopping.
# 'F' is the result of the convolution.
def _smoothing(rho, z):
    """Antiderivative along z of the bump ``f`` at radial distance ``rho``.

    The contribution of a segment ``[z0, z1]`` is
    ``_smoothing(rho, z - z0) - _smoothing(rho, z - z1)``.  Over the full
    chord this integrates to ``(16/15) * (1 - rho**2)**(5/2)``.  Both
    arguments may be scalars or broadcastable arrays.
    """
    # Half-length of the chord of the unit ball at radial distance rho
    # (zero outside the ball).
    r = np.sqrt(np.maximum(1 - rho * rho, 0))
    m = np.clip(z, -r, r)
    rr = r * r
    rrrr = rr * rr
    mm = m * m
    return m * (mm * (mm/5 - (2/3) * rr) + rrrr) + (8 / 15) * rrrr * r
# We need to normalize the smoothing function so that it has unit cross
# section in the plane perpendicular to the hopping.  This is equivalent
# to normalizing the integral of 'f' over the unit ball to 1.
# The smoothing function goes as F(ρ) = (16/15) (1 - ρ^2)^(5/2) in the
# plane perpendicular to the hopping, so for an n-dimensional system the
# cross section is:
#     A_n = (16 / 15) * σ_n * ∫_0^1 ρ^(n-2) (1 - ρ^2)^(5/2) dρ
# where σ_n is the surface element prefactor of the (n-1)-dimensional
# perpendicular space (2 in 2D, 2π in 3D); in 1D, A_1 = F(0) = 16/15.
# Rather than calculate A_n every time, we hard code its value for 1, 2 and
# 3D (indexed by dim - 1).
_smoothing_cross_sections = [16 / 15, np.pi / 3, 32 * np.pi / 105]
# Determine the optimal bump function width from the absolute and
# relative widths provided, and the lengths of all the hoppings in the system
def _optimal_width(lens, abswidth, relwidth, bbox_size):
    """Return the bump width to use for smoothing.

    Precedence: ``abswidth`` if given; otherwise ``relwidth`` times the
    largest bounding-box extent; otherwise four times the shortest hopping
    length that is more than 1e-3 of the longest (or the longest itself if
    none qualifies).
    """
    if abswidth is not None:
        return abswidth
    if relwidth is not None:
        return relwidth * np.max(bbox_size)

    unique_lens = np.unique(lens)
    longest = unique_lens[-1]
    # Skip (near-)zero lengths, which would give a degenerate width.
    shortest_nonzero = next(
        (length for length in unique_lens if length / longest > 1e-3),
        longest,
    )
    return 4 * shortest_nonzero
# Create empty field array that covers the bounding box plus
# some additional padding
def _create_field(dim, bbox_size, width, n, is_current):
    """Allocate a zero-filled field grid and return it with its padding.

    Along each of the ``dim`` spatial axes the grid gets
    ``int(bbox_size[d] * n / width + n)`` points, rounded up to an even
    number.  A trailing axis holds ``dim`` vector components for currents,
    or a single component otherwise.  The padding returned is ``width / 2``.
    """
    shape = []
    for d in range(dim):
        points = int(bbox_size[d] * n / width + n)
        shape.append(points + points % 2)  # bump odd counts up to even
    shape.append(dim if is_current else 1)
    return np.zeros(shape), width / 2
def density_kernel(coords):
    """Bump-function weights of a single site evaluated on grid points.

    ``coords`` holds the grid-point coordinates relative to the site.
    ``_interpolate_field`` passes them already scaled so that the bump
    radius is 1.  The summed squares give the distance to the site.  A
    trailing length-1 axis is appended so that the result has one field
    component.
    """
    squared_distance = np.sum(coords * coords)
    return _bump(np.sqrt(squared_distance))[..., None]
def current_kernel(coords, direction, length):
    """Smoothed current contribution of one hopping on grid points.

    Parameters
    ----------
    coords
        Grid-point coordinates relative to the hopping's first site, scaled
        by ``_interpolate_field`` so that the bump radius is 1.
    direction
        Unit vector along the hopping.
    length
        Hopping length in the same scaled units.

    Returns
    -------
    ``direction`` times the smoothing magnitude, with the vector components
    on the last axis.
    """
    axial = np.dot(coords, direction)
    radial = np.sqrt(np.abs(np.sum(coords * coords) - axial * axial))
    weight = _smoothing(radial, axial) - _smoothing(radial, axial - length)
    return direction * weight[..., None]
# Interpolate a discrete scalar (density) or vector (current) field onto the
# regular grid held by 'field_out'.
def _interpolate_field(dim, elements, discrete_field, bbox, width,
                       padding, field_out):
    """Accumulate smoothed contributions of ``discrete_field`` into ``field_out``.

    Parameters
    ----------
    dim : int
        Number of spatial dimensions.
    elements : np.ndarray
        Site positions of shape ``(nsites, dim)`` (density), or pairs of
        hopping end points of shape ``(nhops, 2, dim)`` (current).
    discrete_field : sequence
        One value per element; elements with a zero value are skipped.
    bbox : pair of arrays
        ``(bbox_min, bbox_max)`` corners of the bounding box of the elements.
    width : float
        Width of the smoothing bump.  Coordinates are multiplied by
        ``2 / width`` before the kernel is evaluated, so the kernel sees a
        bump of radius 1.
    padding : float
        Margin that the grid extends beyond ``bbox`` on every side.
    field_out : np.ndarray
        Output grid, e.g. as allocated by ``_create_field``: ``dim`` spatial
        axes followed by a component axis.  Updated in place; nothing is
        returned.
    """

    field_shape = np.array(field_out.shape)
    bbox_min, bbox_max = bbox

    scale = 2 / width

    # if density elements is shape (nsites, dim)
    # if current elements is shape (nhops, 2, dim)
    assert elements.shape[-1] == dim
    is_current = len(elements.shape) == 3
    if is_current:
        assert elements.shape[1] == 2
        # Unit direction and scaled length of every hopping.
        dirs = elements[:, 1] - elements[:, 0]
        lens = np.sqrt(np.sum(dirs * dirs, axis=-1))
        dirs /= lens[:, None]
        lens = lens * scale

    if is_current:
        pos_offsets = elements[:, 0]  # first site in hopping
        kernel = current_kernel
    else:
        pos_offsets = elements  # sites themselves
        kernel = density_kernel

    # Coordinates of the grid points along each axis.
    region = [np.linspace(bbox_min[d] - padding,
                          bbox_max[d] + padding,
                          field_shape[d])
              for d in range(dim)]

    # Grid points per unit length along each axis.
    grid_density = (field_shape[:dim] - 1) / (bbox_max + 2*padding - bbox_min)

    # slices for indexing 'field' and 'region' array
    # For each element: [start, stop) grid indices per axis covering the
    # element's extent widened by 'padding' on both sides (the first grid
    # point sits at bbox_min - padding, hence the 2*padding on the upper end).
    slices = np.empty((len(discrete_field), dim, 2), int)
    if is_current:
        mn = np.min(elements, 1)
        mx = np.max(elements, 1)
    else:
        mn = mx = elements
    slices[:, :, 0] = np.floor((mn - bbox_min) * grid_density)
    slices[:, :, 1] = np.ceil((mx + 2*padding - bbox_min) * grid_density)

    for i in range(len(discrete_field)):

        if not np.diff(slices[i]).all() or not discrete_field[i]:
            # Zero volume or zero field: nothing to do.
            continue

        field_slice = tuple([slice(*slices[i, d]) for d in range(dim)])

        # Coordinates of the grid points that are within range of the current
        # hopping.
        # The sparse meshgrid arrays are packed into an object array so that
        # the arithmetic below broadcasts over the grid per component.
        # NOTE(review): this relies on np.array building a 1-D object array of
        # length dim; confirm behaviour for slices with a single point.
        coords = np.array(
            np.meshgrid(
                *[region[d][field_slice[d]] for d in range(dim)],
                sparse=True, indexing='ij'
            ),
            dtype=object
        )

        # Convert "coords" into scaled distances from pos_offset
        coords -= pos_offsets[i]
        coords *= scale
        magns = kernel(coords, dirs[i], lens[i]) if is_current else kernel(coords)
        magns *= discrete_field[i]

        field_out[field_slice] += magns

    # Normalize so the smoothed field has unit cross section per unit of
    # discrete field (see '_smoothing_cross_sections').
    field_out *= scale / _smoothing_cross_sections[dim - 1]
def interpolate_current(syst, current, relwidth=None, abswidth=None, n=9):
    """Interpolate currents in a system onto a regular grid.

    The system graph together with current intensities defines a "discrete"
    current density field where the current density is non-zero only on the
    straight lines that connect sites that are coupled by a hopping term.

    To make this vector field easier to visualize and interpret at different
    length scales, it is smoothed by convoluting it with the bell-shaped bump
    function ``f(r) = max(1 - (2*r / width)**2, 0)**2``. The bump width is
    determined by the `relwidth` and `abswidth` parameters.

    This routine samples the smoothed field on a regular (square or cubic)
    grid.

    Parameters
    ----------
    syst : A finalized system
        The system on which we are going to calculate the field.
    current : 1D array of float
        Must contain the intensity on each hoppings in the same order that they
        appear in syst.graph.
    relwidth : float or `None`
        Relative width of the bumps used to generate the field, as a fraction
        of the length of the longest side of the bounding box. This argument
        is only used if `abswidth` is not given.
    abswidth : float or `None`
        Absolute width of the bumps used to generate the field. Takes
        precedence over `relwidth`. If neither is given, the bump width is set
        to four times the length of the shortest hopping.
    n : int
        Number of points the grid must have over the width of the bump.

    Returns
    -------
    field : n-d arraylike of float
        n-d array of n-d vectors.
    box : sequence of 2-sequences of float
        the extents of `field`: ((x0, x1), (y0, y1), ...)

    Raises
    ------
    TypeError
        If `syst` is not a finalized (finite) system.
    ValueError
        If `current` does not have one entry per hopping of `syst.graph`, or
        if the system has no hoppings at all.
    """
    if not builder.is_finite_system(syst):
        raise TypeError("The system needs to be finalized.")

    if len(current) != syst.graph.num_edges:
        raise ValueError("Current and hoppings arrays do not have the same"
                         " length.")

    if syst.graph.num_edges == 0:
        # Without hoppings the bounding-box reduction below would operate on
        # an empty array and fail with an opaque numpy error.
        raise ValueError("Cannot interpolate current in a system without "
                         "hoppings.")

    # hops: hoppings (pairs of points)
    dim = len(syst.sites[0].pos)
    hops = np.empty((syst.graph.num_edges // 2, 2, dim))
    # Take the average of the current flowing each way along the hoppings:
    # each undirected hopping is stored once (indexed by ``kprime``); the
    # reverse edge's current is subtracted, and the sum is halved below.
    current_one_way = np.empty(syst.graph.num_edges // 2)
    seen_hoppings = dict()
    kprime = 0
    for k, (i, j) in enumerate(syst.graph):
        if (j, i) in seen_hoppings:
            current_one_way[seen_hoppings[j, i]] -= current[k]
        else:
            current_one_way[kprime] = current[k]
            # The hopping is oriented from site j to site i.
            hops[kprime][0] = syst.sites[j].pos
            hops[kprime][1] = syst.sites[i].pos
            seen_hoppings[i, j] = kprime
            kprime += 1
    current = current_one_way / 2

    min_hops = np.min(hops, 1)
    max_hops = np.max(hops, 1)
    bbox_min = np.min(min_hops, 0)
    bbox_max = np.max(max_hops, 0)
    bbox_size = bbox_max - bbox_min

    # lens: scaled lengths of hoppings
    # dirs: normalized directions of hoppings
    dirs = hops[:, 1] - hops[:, 0]
    lens = np.sqrt(np.sum(dirs * dirs, -1))
    dirs /= lens[:, None]
    width = _optimal_width(lens, abswidth, relwidth, bbox_size)

    field, padding = _create_field(dim, bbox_size, width, n, is_current=True)
    boundaries = tuple((bbox_min[d] - padding, bbox_max[d] + padding)
                       for d in range(dim))
    _interpolate_field(dim, hops, current,
                       (bbox_min, bbox_max), width, padding, field)

    return field, boundaries
def interpolate_density(syst, density, relwidth=None, abswidth=None, n=9,
                        mask=True):
    """Interpolate density in a system onto a regular grid.

    The system sites together with a scalar for each site define a "discrete"
    density field where the density is non-zero only at the site positions.

    To make this scalar field easier to visualize and interpret at different
    length scales, it is smoothed by convoluting it with the bell-shaped bump
    function ``f(r) = max(1 - (2*r / width)**2, 0)**2``. The bump width is
    determined by the `relwidth` and `abswidth` parameters.

    This routine samples the smoothed field on a regular (square or cubic)
    grid.

    Parameters
    ----------
    syst : A finalized system
        The system on which we are going to calculate the field.
    density : 1D array of float
        Must contain the intensity on each site in the same order that they
        appear in syst.sites.
    relwidth : float, optional
        Relative width of the bumps used to smooth the field, as a fraction
        of the length of the longest side of the bounding box. This argument
        is only used if ``abswidth`` is not given.
    abswidth : float, optional
        Absolute width of the bumps used to smooth the field. Takes
        precedence over ``relwidth``. If neither is given, the bump width is set
        to four times the length of the shortest hopping.
    n : int
        Number of points the grid must have over the width of the bump.
    mask : bool
        If True, this function returns a masked array that masks positions that
        are too far away from any sites (further than 0.6 times the bump
        width). This is useful for showing an approximate outline of the
        system when the field is plotted.

    Returns
    -------
    field : n-d arraylike of float
        n-d array holding the smoothed (scalar) density at each grid point;
        a ``numpy.ma.MaskedArray`` if ``mask`` is True.
    box : sequence of 2-sequences of float
        the extents of ``field``: ((x0, x1), (y0, y1), ...)

    Raises
    ------
    TypeError
        If `syst` is not a finalized (finite) system.
    ValueError
        If `density` does not have one entry per site of `syst`.
    """
    if not builder.is_finite_system(syst):
        raise TypeError("The system needs to be finalized.")

    if len(density) != len(syst.sites):
        raise ValueError("Density and sites arrays do not have the same"
                         " length.")

    dim = len(syst.sites[0].pos)
    sites = np.array([s.pos for s in syst.sites])

    bbox_min = np.min(sites, axis=0)
    bbox_max = np.max(sites, axis=0)
    bbox_size = bbox_max - bbox_min

    # Determine the optimal width for the bump function from the hopping
    # lengths of the system graph.
    dirs = np.array([syst.sites[i].pos - syst.sites[j].pos
                     for i, j in syst.graph])
    lens = np.sqrt(np.sum(dirs * dirs, -1))
    width = _optimal_width(lens, abswidth, relwidth, bbox_size)

    field, padding = _create_field(dim, bbox_size, width, n, is_current=False)
    boundaries = tuple((bbox_min[d] - padding, bbox_max[d] + padding)
                       for d in range(dim))
    _interpolate_field(dim, sites, density,
                       (bbox_min, bbox_max), width, padding, field)

    if mask:
        # Field is zero when we are > 0.5*width from any site (as bump has
        # finite support), so we mask positions a little further than this.
        # ``sites`` already holds the site positions; no need to rebuild it.
        field = _mask(field,
                      box=boundaries,
                      coords=sites,
                      cutoff=0.6*width)

    return field, boundaries
def _gamma_compress(linear):
    """Apply the sRGB transfer function: linear-light value -> sRGB value.

    Values up to 0.0031308 use the linear segment (slope 12.92); larger
    values use the power-law segment with exponent 1/2.4.
    """
    a = 0.055
    return (12.92 * linear if linear <= 0.0031308
            else (1 + a) * linear ** (1 / 2.4) - a)

# Element-wise version; always produces float output.
_gamma_compress = np.vectorize(_gamma_compress, otypes=[float])
def _gamma_expand(corrected):
    """Invert the sRGB transfer function: sRGB value -> linear-light value.

    Values up to 0.04045 are on the linear segment (divided by 12.92);
    larger values are raised to the power 2.4 after offset removal.
    """
    a = 0.055
    return (corrected / 12.92 if corrected <= 0.04045
            else ((corrected + a) / (1 + a)) ** 2.4)

# Element-wise version; always produces float output.
_gamma_expand = np.vectorize(_gamma_expand, otypes=[float])
def _linear_cmap(a, b):
    """Return a 256-entry colormap running from color `b` (at 0) to `a` (at 1).

    The interpolation is done in linear-light sRGB, and the result is
    gamma-compressed back to sRGB, so the blend is perceptually correct.
    """
    to_rgb = _p.matplotlib.colors.colorConverter.to_rgb
    end = _gamma_expand(to_rgb(a))
    start = _gamma_expand(to_rgb(b))
    weights = np.linspace(0, 1, 256)[:, np.newaxis]
    palette = start + weights * (end - start)[np.newaxis, :]
    return _p.matplotlib.colors.ListedColormap(_gamma_compress(palette))
def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
               max_linewidth=3, min_linewidth=1, density=2/9,
               colorbar=True, file=None,
               show=True, dpi=None, fig_size=None, ax=None,
               vmax=None):
    """Draw streamlines of a flow field in Kwant style.

    Solid colored streamlines are drawn, superimposed on a color plot of
    the flow speed that may be disabled by setting `bgcolor`. The width
    of the streamlines is proportional to the flow speed. Lines that
    would be thinner than `min_linewidth` are blended in a perceptually
    correct way into the background color in order to create the
    illusion of arbitrarily thin lines.

    Only the matplotlib engine is currently supported; with the plotly
    engine a `RuntimeError` is raised.

    Parameters
    ----------
    field : 3d arraylike of float
        2d array of 2d vectors.
    box : 2-sequence of 2-sequences of float
        the extents of `field`: ((x0, x1), (y0, y1))
    cmap : colormap, optional
        Colormap for the background color plot. When not set the colormap
        "kwant_red" is used by default, unless `bgcolor` is set.
    bgcolor : color definition, optional
        The solid color of the background. Mutually exclusive with `cmap`.
    linecolor : color definition
        Color of the flow lines.
    max_linewidth : float
        Width of lines at maximum flow speed.
    min_linewidth : float
        Minimum width of lines before blending into the background color begins.
    density : float
        Number of flow lines per point of the field. The default value
        of 2/9 is chosen to show two lines per default width of the
        interpolation bump of `~kwant.plotter.interpolate_current`.
    colorbar : bool
        Whether to show a colorbar if a colormap is used. Ignored if `ax` is
        provided.
    file : string or file object or `None`
        The output file. If `None`, output will be shown instead.
    show : bool
        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
        to be shown immediately. Defaults to `True`.
    dpi : float or `None`
        Number of pixels per inch. If not set the ``matplotlib`` default is
        used.
    fig_size : tuple or `None`
        Figure size `(width, height)` in inches. If not set, the default
        ``matplotlib`` value is used.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
    vmax : float or `None`
        The upper saturation limit for the colormap; flows higher than
        this will saturate. Note that there is no corresponding vmin
        option, vmin being fixed at zero.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    RuntimeError
        If no plotting engine is available, or the engine does not support
        streamplots.
    """
    if _p.engine == "matplotlib":
        fig = _streamplot_matplotlib(field, box, cmap, bgcolor, linecolor,
                                     max_linewidth, min_linewidth, density,
                                     colorbar, file, show, dpi, fig_size, ax,
                                     vmax)
    elif _p.engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)
        fig = _streamplot_plotly(field, box, cmap, bgcolor, linecolor,
                                 max_linewidth, min_linewidth, density,
                                 colorbar, file, show, vmax)
    elif _p.engine is None:
        raise RuntimeError("Cannot use streamplot() without a plotting lib installed")
    else:
        raise RuntimeError("streamplot() does not support engine '{}'".format(_p.engine))
    _maybe_output_fig(fig, file=file, show=show)
    # Return the figure as documented (and as `scalarplot` does), so that
    # `current()` can hand it back to its caller.
    return fig
def _streamplot_plotly(field, box, cmap, bgcolor, linecolor,
                       max_linewidth, min_linewidth, density,
                       colorbar, file, show, vmax):
    """Placeholder for a plotly implementation of `streamplot`.

    All arguments are ignored.

    Raises
    ------
    RuntimeError
        Always; streamlines are not yet supported with the plotly engine.
    """
    message = ("Streamplot() for plotly engine not implemented yet "
               "due to bug from plotly")
    raise RuntimeError(message)
def _streamplot_matplotlib(field, box, cmap, bgcolor, linecolor,
                           max_linewidth, min_linewidth, density, colorbar,
                           file, show, dpi, fig_size, ax, vmax):
    """Draw streamlines of a flow field in Kwant style (matplotlib engine).

    Solid colored streamlines are drawn, superimposed on a color plot of
    the flow speed that may be disabled by setting `bgcolor`. The width
    of the streamlines is proportional to the flow speed. Lines that
    would be thinner than `min_linewidth` are blended in a perceptually
    correct way into the background color in order to create the
    illusion of arbitrarily thin lines. (This is done because some plot
    engines like PDF do not support lines of arbitrarily thin width.)

    Internally, this routine uses matplotlib's streamplot.

    Parameters
    ----------
    field : 3d arraylike of float
        2d array of 2d vectors.
    box : 2-sequence of 2-sequences of float
        the extents of `field`: ((x0, x1), (y0, y1))
    cmap : colormap, optional
        Colormap for the background color plot. When not set the colormap
        "kwant_red" is used by default, unless `bgcolor` is set.
    bgcolor : color definition, optional
        The solid color of the background. Mutually exclusive with `cmap`.
    linecolor : color definition
        Color of the flow lines.
    max_linewidth : float
        Width of lines at maximum flow speed.
    min_linewidth : float
        Minimum width of lines before blending into the background color begins.
    density : float
        Number of flow lines per point of the field. The default value
        of 2/9 is chosen to show two lines per default width of the
        interpolation bump of `~kwant.plotter.interpolate_current`.
    colorbar : bool
        Whether to show a colorbar if a colormap is used. Ignored if `ax` is
        provided.
    file : string or file object or `None`
        The output file. Only used here to decide whether the figure is
        created through pyplot (when `None`).
    show : bool
        Unused here; output is handled by the caller.
    dpi : float or `None`
        Number of pixels per inch. If not set the ``matplotlib`` default is
        used.
    fig_size : tuple or `None`
        Figure size `(width, height)` in inches. If not set, the default
        ``matplotlib`` value is used.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
    vmax : float or `None`
        The upper saturation limit for the colormap; flows higher than
        this will saturate. Note that there is no corresponding vmin
        option, vmin being fixed at zero.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None. Saving or
        showing the figure is left to the caller (`streamplot`), so it is
        output exactly once.

    Raises
    ------
    ValueError
        If `field` is not a 2D field of 2D vectors, or if both `cmap` and
        `bgcolor` are given.
    """
    # Validate first: the transpose below would otherwise fail with an
    # unrelated numpy error for fields that are not 3D.
    if field.ndim != 3 or field.shape[-1] != 2:
        raise ValueError("Only 2D field can be plotted.")

    # Matplotlib's "density" is in units of 30 streamlines...
    density *= 1 / 30 * ta.array(field.shape[:2], int)

    # Matplotlib plots images like matrices: image[y, x]. We use the opposite
    # convention: image[x, y]. Hence, it is necessary to transpose.
    field = field.transpose(1, 0, 2)

    if bgcolor is None:
        if cmap is None:
            cmap = _p.kwant_red_matplotlib
        # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; use the
        # colormap registry (matplotlib >= 3.6) when it is available.
        registry = getattr(_p.matplotlib, 'colormaps', None)
        if hasattr(registry, 'get_cmap'):
            cmap = registry.get_cmap(cmap)
        else:
            cmap = _p.matplotlib.cm.get_cmap(cmap)
        bgcolor = cmap(0)[:3]
    elif cmap is not None:
        raise ValueError("The parameters 'cmap' and 'bgcolor' are "
                         "mutually exclusive.")

    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        ax = fig.add_subplot(1, 1, 1, aspect='equal')
    else:
        fig = None

    X = np.linspace(*box[0], num=field.shape[1])
    Y = np.linspace(*box[1], num=field.shape[0])

    speed = np.linalg.norm(field, axis=-1)
    if vmax is None:
        # Fall back to 1 for an everywhere-zero field, avoiding a division
        # by zero when computing the line widths below.
        vmax = np.max(speed) or 1

    if cmap is None:
        # Solid background color. ``Axes.set_axis_bgcolor`` no longer exists
        # in matplotlib; ``set_facecolor`` is its replacement.
        ax.set_facecolor(bgcolor)
        image = None
    else:
        image = ax.imshow(speed, cmap=cmap,
                          interpolation='bicubic',
                          extent=[e for c in box for e in c],
                          origin='lower', vmin=0, vmax=vmax)

    # Line width grows linearly with speed. Lines that would be thinner than
    # min_linewidth are drawn at min_linewidth and instead blended towards
    # the background via the color value (0 = bgcolor, 1 = linecolor).
    linewidth = max_linewidth / vmax * speed
    color = linewidth / min_linewidth
    thin = linewidth < min_linewidth
    linewidth[thin] = min_linewidth
    color[~thin] = 1

    line_cmap = _linear_cmap(linecolor, bgcolor)

    ax.streamplot(X, Y, field[:, :, 0], field[:, :, 1],
                  density=density, linewidth=linewidth,
                  color=color, cmap=line_cmap, arrowstyle='->',
                  norm=_p.matplotlib.colors.Normalize(0, 1))

    ax.set_xlim(*box[0])
    ax.set_ylim(*box[1])

    if colorbar and image is not None and fig is not None:
        fig.colorbar(image)

    return fig
def scalarplot(field, box,
               cmap=None, colorbar=True, file=None, show=True,
               dpi=None, fig_size=None, ax=None, vmin=None, vmax=None,
               background='#e0e0e0'):
    """Draw a scalar field in Kwant style

    Internally, this routine uses matplotlib's imshow (or a plotly Heatmap
    when the plotly engine is selected).

    Parameters
    ----------
    field : 2d arraylike of float
        2d scalar field to plot, indexed as ``field[x, y]``. A trailing axis
        of length 1 (as in the fields returned by
        `~kwant.plotter.interpolate_density`) is also accepted and dropped.
        Masked arrays are supported.
    box : pair of pair of float
        the realspace extents of ``field``: ((x0, x1), (y0, y1))
    cmap : colormap, optional
        Colormap for the background color plot. When not set the colormap
        "kwant_red" is used by default.
    colorbar : bool, default: True
        Whether to show a colorbar if a colormap is used. Ignored if `ax` is
        provided.
    file : string or file object, optional
        The output file. If not provided, output will be shown instead.
    show : bool, default: True
        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
        to be shown immediately.
    dpi : float, optional
        Number of pixels per inch. If not set the ``matplotlib`` default is
        used.
    fig_size : tuple, optional
        Figure size ``(width, height)`` in inches. If not set, the default
        ``matplotlib`` value is used.
    ax : ``matplotlib.axes.Axes`` instance, optional
        If ``ax`` is provided, no new figure is created, but the plot is done
        within the existing Axes ``ax``. in this case, ``file``, ``show``,
        ``dpi`` and ``fig_size`` are ignored.
    vmin, vmax : float, optional
        The lower/upper saturation limit for the colormap. Default to the
        minimum/maximum of ``field``.
    background : matplotlib color spec
        Areas outside the system are filled with this color.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if ``ax`` is not set, else None.

    Raises
    ------
    ValueError
        If ``field`` is not a 2D scalar field.
    RuntimeError
        If no plotting engine is available or it is not supported.
    """
    # asanyarray keeps masked arrays (e.g. from interpolate_density) intact
    # while also accepting nested sequences.
    field = np.asanyarray(field)

    # Drop the trailing scalar axis if present; plain 2D arrays pass as-is.
    if field.ndim == 3 and field.shape[-1] == 1:
        field = field[..., 0]

    if field.ndim != 2:
        raise ValueError("Only 2D field can be plotted.")

    # Matplotlib plots images like matrices: image[y, x]. We use the opposite
    # convention: image[x, y]. Hence, it is necessary to transpose.
    field = field.transpose()

    if vmin is None:
        vmin = np.min(field)
    if vmax is None:
        vmax = np.max(field)

    if _p.engine == "matplotlib":
        fig = _scalarplot_matplotlib(field, box, cmap, colorbar,
                                     file, show, dpi, fig_size, ax,
                                     vmin, vmax, background)
    elif _p.engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)
        fig = _scalarplot_plotly(field, box, cmap, colorbar, file,
                                 show, vmin, vmax, background)
    elif _p.engine is None:
        raise RuntimeError("Cannot use scalarplot() without a plotting lib installed")
    else:
        raise RuntimeError("scalarplot() does not support engine '{}'".format(_p.engine))
    _maybe_output_fig(fig, file=file, show=show)

    return fig
def _scalarplot_plotly(field, box, cmap, colorbar, file,
                       show, vmin, vmax, background):
    """Build a plotly heatmap figure of a 2D scalar field.

    Parameters
    ----------
    field : 2d array of float (possibly masked)
        Field in image order, ``field[y, x]`` (`scalarplot` transposes its
        input before calling this function).
    box : pair of pair of float
        Real-space extents ((x0, x1), (y0, y1)).
    cmap : plotly colorscale or None
        Defaults to the Kwant red colorscale.
    colorbar : bool
        Whether to show the color scale.
    file, show :
        Unused here; output is handled by the caller.
    vmin, vmax : float
        Color scale limits.
    background : color spec
        Plot background, visible wherever the field is masked.

    Returns
    -------
    fig : plotly Figure
    """
    if cmap is None:
        cmap = _p.kwant_red_plotly

    heatmap = _p.plotly_graph_objs.Heatmap()
    # Masked points (far away from any site) become NaN, which plotly leaves
    # as gaps so that the background color shows through, like it does for
    # the matplotlib engine.
    heatmap.z = np.ma.asarray(field, dtype=float).filled(np.nan)
    # Rows of ``field`` run along y and columns along x, so x needs
    # ``field.shape[1]`` points and y needs ``field.shape[0]``.
    heatmap.x = np.linspace(*box[0], field.shape[1])
    heatmap.y = np.linspace(*box[1], field.shape[0])
    heatmap.zsmooth = 'best'
    heatmap.colorscale = cmap
    heatmap.zmax = vmax
    heatmap.zmin = vmin

    heatmap.showscale = colorbar

    fig = _p.plotly_graph_objs.Figure(data=[heatmap])
    fig.layout.plot_bgcolor = background

    return fig
def _scalarplot_matplotlib(field, box, cmap, colorbar, file, show, dpi,
                           fig_size, ax, vmin, vmax, background):
    """Draw a 2D scalar field with matplotlib's imshow.

    ``field`` is in image order, ``field[y, x]`` (`scalarplot` transposes
    its input before calling this function). ``file`` is only used to decide
    whether the figure is created through pyplot; ``show`` is unused, as
    output is handled by the caller.

    Returns
    -------
    fig : matplotlib figure
        The new figure, or None if ``ax`` was provided.
    """
    if cmap is None:
        cmap = _p.kwant_red_matplotlib
    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; use the
    # colormap registry (matplotlib >= 3.6) when it is available.
    registry = getattr(_p.matplotlib, 'colormaps', None)
    if hasattr(registry, 'get_cmap'):
        cmap = registry.get_cmap(cmap)
    else:
        cmap = _p.matplotlib.cm.get_cmap(cmap)

    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        ax = fig.add_subplot(1, 1, 1, aspect='equal')
    else:
        fig = None

    image = ax.imshow(field, cmap=cmap,
                      interpolation='bicubic',
                      extent=[e for c in box for e in c],
                      origin='lower', vmin=vmin, vmax=vmax)

    ax.set_xlim(*box[0])
    ax.set_ylim(*box[1])
    # Masked pixels are drawn with the colormap's "bad" color (transparent
    # by default), so the axes background shows through there.
    ax.patch.set_facecolor(background)

    # ``cmap`` is always a Colormap at this point, so only these matter.
    if colorbar and fig is not None:
        fig.colorbar(image)

    return fig
def current(syst, current, relwidth=0.05, **kwargs):
    """Show an interpolated current defined for the hoppings of a system.

    The system graph together with current intensities defines a "discrete"
    current density field where the current density is non-zero only on the
    straight lines that connect sites that are coupled by a hopping term.

    To make this vector field easier to visualize and interpret at different
    length scales, it is smoothed by convoluting it with the bell-shaped bump
    function ``f(r) = max(1 - (2*r / width)**2, 0)**2``. The bump width is
    determined by the ``relwidth`` parameter.

    This routine samples the smoothed field on a regular (square or cubic) grid
    and displays it using an enhanced variant of matplotlib's streamplot.

    This is a convenience function that is equivalent to
    ``streamplot(*interpolate_current(syst, current, relwidth), **kwargs)``.
    The longer form makes it possible to tweak additional options of
    `~kwant.plotter.interpolate_current`.

    Parameters
    ----------
    syst : `kwant.system.FiniteSystem`
        The system for which to plot the ``current``.
    current : sequence of float
        Sequence of values defining currents on each hopping of the system.
        Ordered in the same way as ``syst.graph``. This typically will be
        the result of evaluating a `~kwant.operator.Current` operator.
    relwidth : float or `None`
        Relative width of the bumps used to smooth the field, as a fraction
        of the length of the longest side of the bounding box. If `None`,
        the default of `~kwant.plotter.interpolate_current` applies (four
        times the length of the shortest hopping).
    **kwargs : various
        Keyword args to be passed verbatim to `kwant.plotter.streamplot`.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if ``ax`` is not set, else None.

    See Also
    --------
    kwant.plotter.density
    """
    with _common.reraise_warnings(4):
        return streamplot(*interpolate_current(syst, current, relwidth),
                          **kwargs)
def _mask(field, box, coords, cutoff):
    """Mask the grid points of `field` that are far away from all `coords`.

    Parameters
    ----------
    field : ndarray
        Field sampled on a regular grid. Its leading ``len(box)`` axes are
        the spatial grid axes. Any further axes (vector components, or the
        trailing length-1 axis of a scalar field) are masked together with
        their grid point.
    box : sequence of 2-sequences of float
        Extents of the grid, ``((x0, x1), (y0, y1), ...)``. Grid points are
        evenly spaced with both endpoints included.
    coords : array of float, shape (num_points, len(box))
        Positions around which the field is kept.
    cutoff : float
        Grid points with no position of `coords` within this distance are
        masked.

    Returns
    -------
    numpy.ma.MaskedArray
        `field` wrapped in a masked array with the far-away points masked.
    """
    tree = spatial.cKDTree(coords)
    num_dims = len(box)
    grid_shape = field.shape[:num_dims]

    # Positions of all grid points, one row per point, in C order.
    axes = tuple(slice(lo, hi, 1j * size)
                 for (lo, hi), size in zip(box, grid_shape))
    points = np.mgrid[axes].reshape(num_dims, -1).T

    # ``query`` reports an infinite distance when no neighbour lies within
    # ``distance_upper_bound``.
    distances = tree.query(points, distance_upper_bound=cutoff)[0]
    far = (distances == np.inf).reshape(grid_shape)

    # Broadcast the per-grid-point mask over any trailing component axes.
    mask = np.zeros(field.shape, dtype=bool)
    mask |= far.reshape(grid_shape + (1,) * (field.ndim - num_dims))
    return np.ma.masked_array(field, mask)
def density(syst, density, relwidth=0.05, **kwargs):
    """Show an interpolated density defined on the sites of a system.

    The system sites, together with one scalar per site, define a "discrete"
    density field that is non-zero only at the sites.

    This scalar field is smoothed, for easier visualization and
    interpretation at different length scales, by convoluting it with the
    bell-shaped bump function ``f(r) = max(1 - (2*r / width)**2, 0)**2``,
    whose width is set by the ``relwidth`` parameter.

    The smoothed field is sampled on a regular (square or cubic) grid and
    displayed using matplotlib's imshow.

    Compared with `~kwant.plotter.map`, this usually looks better for systems
    with many sites; for site-level resolution `~kwant.plotter.map` may be the
    better choice.

    This is a shorthand for
    ``scalarplot(*interpolate_density(syst, density, relwidth), **kwargs)``;
    use the long form to tweak further options of
    `~kwant.plotter.interpolate_density`.

    Parameters
    ----------
    syst : `kwant.system.FiniteSystem`
        The system whose ``density`` is plotted.
    density : sequence of float
        One density value per site, ordered like ``syst.sites``. Typically
        the result of evaluating a `~kwant.operator.Density` operator.
    relwidth : float or `None`
        Width of the smoothing bumps, as a fraction of the longest side of
        the bounding box.
    **kwargs : various
        Forwarded unchanged to `~kwant.plotter.scalarplot`.

    Returns
    -------
    fig : matplotlib figure
        The figure with the plot when ``ax`` is not given, otherwise None.

    See Also
    --------
    kwant.plotter.current
    kwant.plotter.map
    """
    with _common.reraise_warnings(4):
        field, box = interpolate_density(syst, density, relwidth)
        return scalarplot(field, box, **kwargs)
2951# TODO (Anton): Fix plotting of parts of the system using color = np.nan.
2952# Not plotting sites currently works, not plotting hoppings does not.
2953# TODO (Anton): Allow a more flexible treatment of position than pos_transform
2954# (an interface for user-defined pos).