Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2# Copyright 2011-2019 Kwant authors. 

3# 

4# This file is part of Kwant. It is subject to the license terms in the file 

5# LICENSE.rst found in the top-level directory of this distribution and at 

6# https://kwant-project.org/license. A list of Kwant authors can be found in 

7# the file AUTHORS.rst at the top-level directory of this distribution and at 

8# https://kwant-project.org/authors. 

9 

10"""Plotter module for Kwant. 

11 

12This module provides iterators useful for any plotter routine, such as a list 

13of system sites, their coordinates, lead sites at any lead unit cell, etc. If 

14`matplotlib` is available, it also provides simple functions for plotting the 

15system in two or three dimensions. 

16""" 

17 

18from collections import defaultdict 

19import sys 

20import itertools 

21import functools 

22import warnings 

23import cmath 

24import numpy as np 

25import tinyarray as ta 

26from scipy import spatial, interpolate 

27from math import cos, sin, pi, sqrt 

28 

29from . import system, builder, _common 

30from ._common import deprecate_args 

31 

32 

__all__ = [
    # Engine selection.
    'set_engine', 'get_engine',
    # Plotting routines.
    'plot', 'map', 'bands', 'spectrum', 'current', 'density',
    'interpolate_current', 'interpolate_density',
    'streamplot', 'scalarplot',
    # Data-extraction and interpolation helpers.
    'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos',
    'sys_leads_hopping_pos', 'mask_interpolate',
]

# The expensive imports (matplotlib, plotly, ...) all happen in _plotter.py.
# That module is loaded lazily so that importing Kwant itself stays fast.
_p = _common.lazy_import('_plotter')

43 

44 

def set_engine(engine):
    """Set the plotting engine to use.

    Parameters
    ----------
    engine : str
        Options are: 'matplotlib', 'plotly'.

    Raises
    ------
    RuntimeError
        If ``engine`` is not one of the engines listed by the lazily loaded
        ``_plotter`` module (``_p.engines``).

    Notes
    -----
    If neither matplotlib nor plotly is available, a ``RuntimeWarning`` is
    emitted and the current engine is left unchanged. When the active engine
    is plotly and the code runs inside an IPython kernel, plotly's notebook
    mode is initialised, at most once.
    """
    if _p.mpl_available or _p.plotly_available:
        # Validate explicitly rather than with ``assert``: assertions are
        # stripped under ``python -O``, which would accept any engine name.
        # This also avoids a bare ``except:`` hiding unrelated errors.
        if engine not in _p.engines:
            raise RuntimeError(
                "Tried to set an unknown engine '{}'. Supported engines "
                "are {}".format(engine, list(_p.engines)))
        _p.engine = engine
    else:
        warnings.warn("Tried to set '{}' but is not "
                      "available.".format(engine), RuntimeWarning)

    # Plotly needs its notebook mode switched on once per IPython session.
    if _p.engine == "plotly" and not _p.init_notebook_mode_set:
        if _p.is_ipython_kernel:
            _p.init_notebook_mode_set = True
            _p.plotly_module.init_notebook_mode(connected=True)

73 

74 

def get_engine():
    """Return the plotting engine currently in use.

    Returns
    -------
    engine
        The engine value stored on the lazily loaded ``_plotter`` module,
        as last set by `set_engine` (e.g. ``'matplotlib'`` or ``'plotly'``).
    """
    current_engine = _p.engine
    return current_engine

77 

78 

def _check_incompatible_args_plotly(dpi, fig_size, ax):
    """Reject matplotlib-only arguments while the plotly engine is active.

    Parameters
    ----------
    dpi, fig_size, ax
        Matplotlib-specific options of a plotting routine. Each must be
        left as ``None`` when plotting with plotly.

    Raises
    ------
    RuntimeError
        If any of ``dpi``, ``fig_size`` or ``ax`` was specified.
    """
    # Internal invariant: callers only use this on the plotly code path.
    assert _p.engine == "plotly"
    # Compare against ``None`` instead of testing truthiness. Truthiness
    # raises for multi-element NumPy arrays (e.g. ``fig_size``) and silently
    # accepts falsy-but-specified values such as ``dpi=0``.
    if any(arg is not None for arg in (dpi, fig_size, ax)):
        raise RuntimeError(
            "Plotly engine does not support setting 'dpi', 'fig_size' "
            "or 'ax', either leave these parameters unspecified, or "
            "select the matplotlib engine with "
            "'kwant.plotter.set_engine(\"matplotlib\")'")

87 

88 

def _sample_array(array, n_samples, rng=None):
    """Return the entries of ``array`` at up to ``n_samples`` random indices.

    Indices are drawn without replacement, so at most ``len(array)`` entries
    are returned. ``rng`` is normalised through ``_common.ensure_rng``.
    ``array`` must support indexing with an array of integers (e.g. a NumPy
    array).
    """
    generator = _common.ensure_rng(rng)
    total = len(array)
    count = min(n_samples, total)
    picked = generator.choice(range(total), count, replace=False)
    return array[picked]

93 

94 

95# matplotlib helper functions. 

96 

def _color_cycle():
    """Return an endless iterator over matplotlib's color cycle.

    The colors come from ``rcParams['axes.prop_cycle']``, read when this
    function is called.
    """
    prop_cycle = _p.matplotlib.rcParams['axes.prop_cycle']
    cycle_colors = (entry['color'] for entry in prop_cycle)
    return itertools.cycle(cycle_colors)

101 

102 

def _make_figure(dpi, fig_size, use_pyplot=False):
    """Create a new matplotlib figure.

    Parameters
    ----------
    dpi : float or None
        Figure resolution. If ``None``, matplotlib's default is kept.
    fig_size : (width, height) or None
        Figure size passed to ``set_figwidth``/``set_figheight``. If
        ``None``, matplotlib's default is kept.
    use_pyplot : bool
        If true, create the figure with ``matplotlib.pyplot.figure()``.
        Otherwise create a bare ``Figure`` and attach an Agg canvas to it.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    # Importing pyplot/backends below selects the matplotlib backend for
    # good. Warn if no backend has been chosen yet. This check is the same
    # as the one performed inside matplotlib.use.
    if 'matplotlib.backends' not in sys.modules:
        warnings.warn(
            "Kwant's plotting functions have\nthe side effect of "
            "selecting the matplotlib backend. To avoid this "
            "warning,\nimport matplotlib.pyplot, "
            "matplotlib.backends or call matplotlib.use().",
            RuntimeWarning, stacklevel=3
        )

    if not use_pyplot:
        from matplotlib.backends.backend_agg import FigureCanvasAgg
        fig = _p.Figure()
        fig.canvas = FigureCanvasAgg(fig)
    else:
        # Imported at the last possible moment (see the note above).
        from matplotlib import pyplot
        fig = pyplot.figure()

    if dpi is not None:
        fig.set_dpi(dpi)
    if fig_size is not None:
        fig.set_figwidth(fig_size[0])
        fig.set_figheight(fig_size[1])
    return fig

129 

130 

def _maybe_output_fig(fig, file=None, show=True):
    """Output a figure using a given output mode.

    Parameters
    ----------
    fig : matplotlib.figure.Figure instance, plotly figure, or None
        The figure to be output. If ``None``, nothing is done.
    file : string or a file object
        The name of the target file or the target file itself
        (opened for writing).
    show : bool
        Whether to display the figure.

        - With the matplotlib engine, this calls
          ``matplotlib.pyplot.show()`` and only has an effect if not
          saving to a file.
        - With the plotly engine, the figure is displayed with ``iplot``
          even when it was also written to ``file``. Outside a
          Jupyter/IPython kernel, ``show=True`` raises ``RuntimeError``.

    Notes
    -----
    The behavior of this function producing a file is different from that of
    matplotlib in that the `dpi` attribute of the figure is used by default
    instead of the matplotlib config setting.
    """
    if fig is None:
        return

    engine = _p.engine

    if engine == "matplotlib":
        if file is not None:
            fig.canvas.print_figure(file, dpi=fig.dpi)
            return
        if show:
            # With no file given, pyplot should already be available, so it
            # can be imported here without additional warnings.
            from matplotlib import pyplot
            pyplot.show()
        return

    if engine == "plotly":
        if file is not None:
            _p.plotly_module.plot(fig, show_link=False, filename=file,
                                  auto_open=False)
        if show:
            if not _p.is_ipython_kernel:
                raise RuntimeError('show flag using the plotly engine can '
                                   'only be True if and only if called from a '
                                   'jupyter/ipython environment.')
            _p.plotly_module.iplot(fig)

172 

173 

def set_colors(color, collection, cmap, norm=None):
    """Process a color specification to a format accepted by collections.

    If ``color`` is a 1D sequence of numbers with one entry per element of
    ``collection``, the numbers are color-mapped through ``cmap`` and
    ``norm``. Any other specification is converted to explicit RGBA colors.

    Parameters
    ----------
    color : color specification
        A matplotlib color, a sequence of matplotlib colors, or a sequence
        of floats (one per element of ``collection``).
    collection : instance of a subclass of ``matplotlib.collections.Collection``
        Collection to which the color is added.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    """
    n_elements = max(len(collection.get_paths()),
                     len(collection.get_offsets()))

    # matplotlib gets confused if dtype='object'.
    if isinstance(color, np.ndarray) and color.dtype == np.dtype('object'):
        color = tuple(color)

    # For 3D line collections, count the stored 3D segments instead
    # (this works around a matplotlib quirk).
    if _p.has3d and isinstance(collection, _p.mplot3d.art3d.Line3DCollection):
        n_elements = len(collection._segments3d)

    if _p.isarray(color) and len(color) == n_elements:
        try:
            # A 1D array of floats is color-mapped; anything else falls
            # through to the RGBA conversion below.
            color = np.asarray(color, dtype=float)
            if color.ndim == 1:
                collection.set_array(color)
                collection.set_cmap(cmap)
                collection.set_norm(norm)
                collection.set_color(None)
                return
        except (TypeError, ValueError):
            pass

    rgba = _p.matplotlib.colors.colorConverter.to_rgba_array(color)
    collection.set_color(rgba)

212 

213 

def percentile_bound(data, vmin, vmax, percentile=96, stretch=0.1):
    """Return the bounds that capture at least 'percentile' of 'data'.

    If 'vmin' or 'vmax' are provided, then the corresponding bound is
    exactly 'vmin' or 'vmax'. First we set the bounds such that the
    provided percentile of the data is within them. Then we try to
    extend the bounds to cover all the data, maximally stretching each
    bound by a factor 'stretch'.

    Parameters
    ----------
    data : array_like
        Values to bound. May have any shape; it is flattened.
    vmin, vmax : float or None
        Fixed lower/upper bound, or ``None`` to derive it from ``data``.
    percentile : float
        Percentage (0-100) of the data that must lie within the bounds.
        The excluded part is split equally between both tails.
    stretch : float
        Maximal extension of each bound, as a fraction of the distance
        between the two percentile bounds.

    Returns
    -------
    (lower, upper) : tuple
    """
    if vmin is not None and vmax is not None:
        return vmin, vmax

    tail = (100 - percentile) / 2
    # ``np.ravel`` accepts any array_like (``data.flatten()`` required an
    # ndarray) and keeps ndarray subclasses, as ``flatten`` did.
    values = np.ravel(data)
    mn, bound_mn, bound_mx, mx = np.percentile(
        values, (0, tail, 100 - tail, 100))

    if vmin is not None:
        bound_mn = vmin
    if vmax is not None:
        bound_mx = vmax

    # Stretch the lower and upper bounds to cover all the data, if
    # we stretch the bound by less than a factor 'stretch'.
    max_extension = (bound_mx - bound_mn) * stretch
    out_mn = max(bound_mn - max_extension, mn) if vmin is None else vmin
    out_mx = min(bound_mx + max_extension, mx) if vmax is None else vmax

    return (out_mn, out_mx)

240 

241 

# Shorthand symbol names understood by `get_symbol`:
# - 'O' is an alias for the circle 'o'.
# - 's' / 'S' are squares, i.e. 4-vertex polygons rotated by 45 degrees.
#   'p' gives an inscribed circle of radius 1; 'P' gives an area equal to
#   that of a unit circle.
symbol_dict = {
    'O': 'o',
    's': ('p', 4, 45),
    'S': ('P', 4, 45),
}

243 

def get_symbol(symbols):
    """Return the paths corresponding to the description in ``symbols``.

    Parameters
    ----------
    symbols : symbol specification or sequence of them
        Each specification is one of:

        - ``'o'`` (or the alias ``'O'``): circle of radius 1.
        - ``'s'`` / ``'S'``: aliases for ``('p', 4, 45)`` / ``('P', 4, 45)``.
        - ``(kind, n, angle)`` with ``kind`` in ``('p', 'P')``: regular
          polygon with ``n`` vertices, rotated by ``angle`` degrees. For
          ``'p'`` the inscribed circle has radius 1; for ``'P'`` the area
          equals that of a unit circle.
        - A ``matplotlib.path.Path`` instance, used as is.

    Returns
    -------
    paths : list of matplotlib.path.Path
        One path per specification, in the same order.

    Raises
    ------
    ValueError
        If a specification is not recognized.
    """
    # Figure out if list of symbols or single symbol.
    if not hasattr(symbols, '__getitem__'):
        symbols = [symbols]
    elif len(symbols) == 3 and symbols[0] in ('p', 'P'):
        # Most likely a polygon specification (at least not a valid other
        # symbol).
        symbols = [symbols]

    # Only strings can be shorthand names. Checking the type first also
    # avoids a TypeError for unhashable specifications such as
    # ``['p', 6, 0]``.
    symbols = [symbol_dict.get(symbol, symbol) if isinstance(symbol, str)
               else symbol for symbol in symbols]

    paths = []
    for symbol in symbols:
        if isinstance(symbol, _p.matplotlib.path.Path):
            # A ready-made path. Append it (rather than returning it bare)
            # so the return type is always a list and later entries are kept.
            paths.append(symbol)
            continue

        if hasattr(symbol, '__getitem__') and len(symbol) == 3:
            kind, n, angle = symbol
            if kind == 'p':
                # Circumradius for which the inscribed circle has radius 1.
                radius = 1. / cos(pi / n)
            elif kind == 'P':
                # make the polygon such that it has area equal
                # to a unit circle
                radius = sqrt(2 * pi / (n * sin(2 * pi / n)))
            else:
                raise ValueError("Unknown symbol definition " + str(symbol))
            patch = _p.matplotlib.patches.RegularPolygon(
                (0, 0), n, radius=radius, orientation=pi * angle / 180)
        elif symbol == 'o':
            patch = _p.matplotlib.patches.Circle((0, 0), 1)
        else:
            # Without this branch, ``patch`` would be unbound or, worse,
            # silently reused from a previous iteration.
            raise ValueError("Unknown symbol definition " + str(symbol))

        paths.append(patch.get_path().transformed(patch.get_transform()))

    return paths

284 

285 

def symbols(axes, pos, symbol='o', size=1, reflen=None, facecolor='k',
            edgecolor='k', linewidth=None, cmap=None, norm=None, zorder=0,
            **kwargs):
    """Add a collection of symbols (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the symbols have to be added.
    pos : 2d array of shape (n, 2) or (n, 3)
        Coordinates of each symbol.
    symbol : symbol definition
        Passed to `get_symbol`: e.g. ``'o'``, ``'s'``, ``('p', nvert, angle)``,
        a ``matplotlib.path.Path``, or a sequence of these. If it is
        ``'no symbol'`` (everywhere), an empty collection is added.
    size : float or 1d array
        Size(s) of the symbols. Defaults to 1. If all sizes are zero, an
        empty collection is added.
    reflen : float or None, optional
        If ``reflen`` is ``None``, the symbol sizes and linewidths are
        given in points (absolute size in the figure space). If
        ``reflen`` is a number, the symbol sizes and linewidths are
        given in units of ``reflen`` in data space (i.e. scales with the
        scale of the plot). Defaults to ``None``.
    facecolor : color definition, optional
    edgecolor : color definition, optional
        Defines the fill and edge color of the symbol, respectively.
        Either a single object that is a proper matplotlib color
        definition or a sequence of such objects of appropriate
        length. Defaults to all black.
    linewidth : float, 1d array or None, optional
        Line width(s) of the symbol edges.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    zorder : int
        Order in which different collections are drawn: larger
        ``zorder`` means the collection is drawn over collections with
        smaller ``zorder`` values.
    **kwargs : dict keyword arguments to
        pass to `PathCollection` or `Path3DCollection`, respectively.

    Returns
    -------
    `PathCollection` or `Path3DCollection` instance containing all the
    symbols that were added.
    """
    dim = pos.shape[1]
    assert dim == 2 or dim == 3

    # Internally, size must be array_like. ``np.ndim`` also catches NumPy
    # scalars, which raise IndexError (not TypeError) when indexed.
    if np.ndim(size) == 0:
        size = (size, )

    if dim == 2:
        Collection = _p.PathCollection
    else:
        Collection = _p.Path3DCollection

    # ``np.asarray`` makes the comparison elementwise even though ``size``
    # may be a tuple or list (``(0,) == 0`` is simply False).
    if (len(pos) == 0 or np.all(symbol == 'no symbol')
            or np.all(np.asarray(size) == 0)):
        paths = []
        pos = np.empty((0, dim))
    else:
        paths = get_symbol(symbol)

    coll = Collection(paths, sizes=size, reflen=reflen, linewidths=linewidth,
                      offsets=pos, transOffset=axes.transData, zorder=zorder)

    set_colors(facecolor, coll, cmap, norm)
    coll.set_edgecolors(edgecolor)

    coll.update(kwargs)

    if dim == 2:
        axes.add_collection(coll)
    else:
        axes.add_collection3d(coll)

    return coll

364 

365 

def lines(axes, pos0, pos1, reflen=None, colors='k', linestyles='solid',
          cmap=None, norm=None, zorder=0, **kwargs):
    """Add a collection of line segments (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the lines have to be added.
    pos0 : 2d array of shape (n, 2) or (n, 3)
        Starting coordinates of each line segment.
    pos1 : 2d array of shape (n, 2) or (n, 3)
        Ending coordinates of each line segment.
    reflen : float or None, optional
        If `reflen` is `None`, the linewidths are given in points (absolute
        size in the figure space). If `reflen` is a number, the linewidths
        are given in units of `reflen` in data space (i.e. scales with
        the scale of the plot). Defaults to `None`.
    colors : color definition, optional
        Either a single object that is a proper matplotlib color definition
        or a sequence of such objects of appropriate length. Defaults to all
        segments black.
    linestyles : linestyle definition, optional
        Either a single object that is a proper matplotlib line style
        definition or a sequence of such objects of appropriate length.
        Defaults to all segments solid.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    zorder : int
        Order in which different collections are drawn: larger
        `zorder` means the collection is drawn over collections with
        smaller `zorder` values.
    **kwargs : dict keyword arguments to
        pass to `LineCollection` or `Line3DCollection`, respectively.
        If ``linewidths`` is given and is zero (a scalar zero, or all
        entries of an array), an empty collection is added.

    Returns
    -------
    `LineCollection` or `Line3DCollection` instance containing all the
    segments that were added.

    Raises
    ------
    ValueError
        If ``pos0`` and ``pos1`` have different shapes.
    """
    if pos0.shape != pos1.shape:
        raise ValueError('Incompatible lengths of coordinate arrays.')

    dim = pos0.shape[1]
    assert dim == 2 or dim == 3
    if dim == 2:
        Collection = _p.LineCollection
    else:
        Collection = _p.Line3DCollection

    # Works for scalar and per-segment (array) line widths alike. A bare
    # ``kwargs['linewidths'] == 0`` on an array would make the ``if`` below
    # raise "truth value of an array is ambiguous".
    invisible = ('linewidths' in kwargs
                 and np.all(np.asarray(kwargs['linewidths']) == 0))

    if len(pos0) == 0 or invisible:
        # Nothing to draw, but still add (and return) an empty collection.
        coll = Collection([], reflen=reflen, linestyles=linestyles,
                          zorder=zorder)
    else:
        # Interleave start and end points into an (n, 2, dim) segment array.
        segments = np.c_[pos0, pos1].reshape(pos0.shape[0], 2, dim)
        coll = Collection(segments, reflen=reflen, linestyles=linestyles,
                          zorder=zorder)
        set_colors(colors, coll, cmap, norm)
    coll.update(kwargs)

    if dim == 2:
        axes.add_collection(coll)
    else:
        axes.add_collection3d(coll)

    return coll

442 

443 

444# Extracting necessary data from the system. 

445 

def sys_leads_sites(sys, num_lead_cells=2):
    """Return all the sites of the system and of the leads as a list.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.System instance
        The system, sites of which should be returned.
    num_lead_cells : integer
        The number of times lead sites from each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    sites : list of (site, lead_number, copy_number) tuples
        A site is a `~kwant.system.Site` instance if the system is not
        finalized, and an integer otherwise. For system sites
        `lead_number` is `None` and `copy_number` is `0`; for leads both are
        integers.
    lead_cells : list of slices
        `lead_cells[i]` gives the position of all the coordinates of lead
        `i` within `sites`.

    Raises
    ------
    TypeError
        If ``sys`` is neither a Builder nor a finite finalized system.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e. sites of `~kwant.builder.BuilderLead` leads are returned
    with an unfinalized system, and sites of ``system.InfiniteSystem`` leads
    are returned with a finalized system. Leads without an interface (and,
    for finalized systems, leads lacking a graph or a symmetry) contribute
    no sites, but still get an (empty) entry in ``lead_cells``.
    """
    syst = sys  # the parameter name shadows the ``sys`` module
    lead_cells = []

    if isinstance(syst, builder.Builder):
        sites = [(site, None, 0) for site in syst.sites()]
        for lead_nr, lead in enumerate(syst.leads):
            first = len(sites)
            if hasattr(lead, 'builder') and len(lead.interface):
                sites.extend((site, lead_nr, cell)
                             for site in lead.builder.sites()
                             for cell in range(num_lead_cells))
            lead_cells.append(slice(first, len(sites)))
    elif system.is_finite(syst):
        sites = [(i, None, 0) for i in range(syst.graph.num_nodes)]
        for lead_nr, lead in enumerate(syst.leads):
            first = len(sites)
            # Only leads with a graph and a symmetry are plotted.
            if (hasattr(lead, 'graph') and hasattr(lead, 'symmetry')
                    and len(syst.lead_interfaces[lead_nr])):
                sites.extend((site, lead_nr, cell)
                             for site in range(lead.cell_size)
                             for cell in range(num_lead_cells))
            lead_cells.append(slice(first, len(sites)))
    else:
        raise TypeError('Unrecognized system type.')

    return sites, lead_cells

500 

501 

def sys_leads_pos(sys, site_lead_nr):
    """Return an array of positions of sites in a system.

    Parameters
    ----------
    sys : `kwant.builder.Builder` or `kwant.system.System` instance
        The system, coordinates of sites of which should be returned.
    site_lead_nr : list of `(site, leadnr, copynr)` tuples
        Output of `sys_leads_sites` applied to the system.

    Returns
    -------
    coords : numpy.ndarray of floats
        Array of coordinates of the sites. Lead sites are translated by
        their lead's period vector according to their copy number.

    Raises
    ------
    ValueError
        If the site positions do not all have the same length.

    Notes
    -----
    This function uses `site.pos` property to get the position of a builder
    site and `sys.pos(sitenr)` for finalized systems. This function requires
    that all the positions of all the sites have the same dimensionality.
    """
    # Note about efficiency (also applies to sys_leads_hopping_pos):
    # NumPy is really slow when making a NumPy array from many tinyarrays
    # (the buffer interface seems very slow). It is much faster to first
    # pack all positions into one tinyarray and convert that in one go.

    syst = sys  # for naming consistency inside function bodies
    is_builder = isinstance(syst, builder.Builder)
    num_lead_cells = site_lead_nr[-1][2] + 1
    if is_builder:
        pos = np.array(ta.array([i[0].pos for i in site_lead_nr]))
    else:
        def syst_from_lead(lead):
            return syst if lead is None else syst.leads[lead]
        pos = np.array(ta.array([syst_from_lead(i[1]).pos(i[0])
                                 for i in site_lead_nr]))
    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    # Integer-valued positions would otherwise make the in-place addition of
    # the (float) translation vectors below fail with a casting error.
    pos = np.asarray(pos, dtype=float)
    dim = pos.shape[1]

    def get_vec_domain(lead_nr):
        """Return (period vector, domain index one past the interface)."""
        if lead_nr is None:
            return np.zeros((dim,)), 0
        if is_builder:
            sym = syst.leads[lead_nr].builder.symmetry
            try:
                site = syst.leads[lead_nr].interface[0]
            except IndexError:
                return (0, 0)
        else:
            try:
                sym = syst.leads[lead_nr].symmetry
                site = syst.sites[syst.lead_interfaces[lead_nr][0]]
            except (AttributeError, IndexError):
                # empty leads, or leads without symmetry aren't drawn anyways
                return (0, 0)
        dom = sym.which(site)[0] + 1
        # Conversion to numpy array here useful for efficiency
        vec = np.array(sym.periods)[0]
        return vec, dom

    vecs_doms = {i: get_vec_domain(i) for i in range(len(syst.leads))}
    vecs_doms[None] = np.zeros((dim,)), 0
    # Precompute the translation of each lead copy: copy c of a lead is
    # shifted by ``vec * (dom + c)``.
    for k, (vec, dom) in vecs_doms.items():
        vecs_doms[k] = [vec * i for i in range(dom, dom + num_lead_cells)]
    pos += [vecs_doms[i[1]][i[2]] for i in site_lead_nr]
    return pos

570 

571 

def sys_leads_hoppings(sys, num_lead_cells=2):
    """Return all the hoppings of the system and of the leads as a list.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.System instance
        The system, hoppings of which should be returned.
    num_lead_cells : integer
        The number of times lead hoppings from each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    hoppings : list of (hopping, lead_number, copy_number) tuples
        A hopping is a pair of sites. A site is a `~kwant.system.Site`
        instance if the system is not finalized, and an integer otherwise.
        For system hoppings `lead_number` is `None` and `copy_number` is `0`;
        for leads both are integers.
    lead_cells : list of slices
        `lead_cells[i]` gives the position of all the coordinates of lead
        `i` within `hoppings`.

    Raises
    ------
    TypeError
        If ``sys`` is neither a Builder nor a `kwant.system.System`.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e. hoppings of `~kwant.builder.BuilderLead` leads are returned
    with an unfinalized system, and hoppings of
    `~kwant.system.InfiniteSystem` leads are returned with a finalized
    system.
    """
    syst = sys  # the parameter name shadows the ``sys`` module
    hoppings = []
    lead_cells = []

    if isinstance(syst, builder.Builder):
        def folded_lead_hoppings(lead_builder):
            # Translate each hopping by the larger of its two domain
            # indices. One end then lies in the fundamental domain and the
            # other in a non-positive domain. The direction of the hopping
            # is chosen arbitrarily.
            # NOTE(Anton): This may need to be revisited with the future
            # builder format changes.
            sym = lead_builder.symmetry
            for site2, site1 in lead_builder.hoppings():
                shift = max(sym.which(site1)[0], sym.which(site2)[0])
                yield sym.act([-shift], site2), sym.act([-shift], site1)

        hoppings.extend((hop, None, 0) for hop in syst.hoppings())
        for lead_nr, lead in enumerate(syst.leads):
            first = len(hoppings)
            if hasattr(lead, 'builder') and len(lead.interface):
                hoppings.extend((hop, lead_nr, cell)
                                for hop in folded_lead_hoppings(lead.builder)
                                for cell in range(num_lead_cells))
            lead_cells.append(slice(first, len(hoppings)))
    elif isinstance(syst, system.System):
        def graph_hoppings(graph_syst):
            # Keep only i < j so that each graph edge is reported once.
            for i in range(graph_syst.graph.num_nodes):
                for j in graph_syst.graph.out_neighbors(i):
                    if i < j:
                        yield i, j

        hoppings.extend((hop, None, 0) for hop in graph_hoppings(syst))
        for lead_nr, lead in enumerate(syst.leads):
            first = len(hoppings)
            # Only leads with a graph and a symmetry are plotted.
            if (hasattr(lead, 'graph') and hasattr(lead, 'symmetry')
                    and len(syst.lead_interfaces[lead_nr])):
                hoppings.extend((hop, lead_nr, cell)
                                for hop in graph_hoppings(lead)
                                for cell in range(num_lead_cells))
            lead_cells.append(slice(first, len(hoppings)))
    else:
        raise TypeError('Unrecognized system type.')

    return hoppings, lead_cells

647 

648 

def sys_leads_hopping_pos(sys, hop_lead_nr):
    """Return arrays of coordinates of all hoppings in a system.

    Parameters
    ----------
    sys : ``~kwant.builder.Builder`` or ``~kwant.system.System`` instance
        The system, coordinates of hoppings of which should be returned.
    hop_lead_nr : list of ``(hopping, leadnr, copynr)`` tuples
        Output of `sys_leads_hoppings` applied to the system.

    Returns
    -------
    coords : (end_site, start_site): tuple of NumPy arrays of floats
        Coordinates of the first and of the second site of each hopping,
        respectively. Lead hoppings are translated by their lead's period
        vector according to their copy number. If ``hop_lead_nr`` is empty,
        two empty arrays of shape ``(0, 3)`` are returned.

    Raises
    ------
    ValueError
        If the site positions do not all have the same length.

    Notes
    -----
    This function uses ``site.pos`` property to get the position of a builder
    site and ``sys.pos(sitenr)`` for finalized systems. This function requires
    that all the positions of all the sites have the same dimensionality.
    """
    syst = sys  # for naming consistency inside function bodies
    is_builder = isinstance(syst, builder.Builder)
    if len(hop_lead_nr) == 0:
        return np.empty((0, 3)), np.empty((0, 3))
    num_lead_cells = hop_lead_nr[-1][2] + 1
    # Each row holds both endpoints concatenated: (x1, y1, ..., x2, y2, ...).
    if is_builder:
        pos = np.array(ta.array([ta.array(tuple(i[0][0].pos) +
                                          tuple(i[0][1].pos))
                                 for i in hop_lead_nr]))
    else:
        def syst_from_lead(lead):
            return syst if lead is None else syst.leads[lead]
        pos = np.array(ta.array([
            ta.array(tuple(syst_from_lead(i[1]).pos(i[0][0])) +
                     tuple(syst_from_lead(i[1]).pos(i[0][1])))
            for i in hop_lead_nr]))
    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    # Integer-valued positions would otherwise make the in-place addition of
    # the (float) translation vectors below fail with a casting error.
    pos = np.asarray(pos, dtype=float)
    # Twice the spatial dimension, since both endpoints are in each row.
    dim = pos.shape[1]

    def get_vec_domain(lead_nr):
        """Return (doubled period vector, domain index one past interface)."""
        if lead_nr is None:
            return np.zeros((dim,)), 0
        if is_builder:
            sym = syst.leads[lead_nr].builder.symmetry
            try:
                site = syst.leads[lead_nr].interface[0]
            except IndexError:
                return (0, 0)
        else:
            try:
                sym = syst.leads[lead_nr].symmetry
                site = syst.sites[syst.lead_interfaces[lead_nr][0]]
            except (AttributeError, IndexError):
                # empty leads or leads without symmetry are not drawn anyways
                return (0, 0)
        dom = sym.which(site)[0] + 1
        vec = np.array(sym.periods)[0]
        # Both endpoints of a hopping are translated by the same vector.
        return np.r_[vec, vec], dom

    vecs_doms = {i: get_vec_domain(i) for i in range(len(syst.leads))}
    vecs_doms[None] = np.zeros((dim,)), 0
    for k, (vec, dom) in vecs_doms.items():
        vecs_doms[k] = [vec * i for i in range(dom, dom + num_lead_cells)]
    pos += [vecs_doms[i[1]][i[2]] for i in hop_lead_nr]
    return np.copy(pos[:, : dim // 2]), np.copy(pos[:, dim // 2:])

720 

721 

# Useful plot functions (to be extended).

# Default plotting parameters. For the nested dictionaries, the key is the
# dimension of the plotted system: e.g. the default ``site_size`` is 0.25
# for a 2D system and 0.5 for a 3D system. The default plotly symbol size
# is 6 px ('plotly_site_size_reference').
defaults = dict(
    site_symbol={2: 'o', 3: 'o'},
    site_size={2: 0.25, 3: 0.5},
    plotly_site_size_reference=6,
    site_color={2: 'black', 3: 'white'},
    site_edgecolor={2: 'black', 3: 'black'},
    site_lw={2: 0, 3: 0.1},
    hop_color={2: 'black', 3: 'black'},
    hop_lw={2: 0.1, 3: 0},
    lead_color={2: 'red', 3: 'red'},
)

736 

737 

def plot(sys, num_lead_cells=2, unit=None,
         site_symbol=None, site_size=None,
         site_color=None, site_edgecolor=None, site_lw=None,
         hop_color=None, hop_lw=None,
         lead_site_symbol=None, lead_site_size=None, lead_color=None,
         lead_site_edgecolor=None, lead_site_lw=None,
         lead_hop_lw=None, pos_transform=None,
         cmap='gray', colorbar=True, file=None,
         show=True, dpi=None, fig_size=None, ax=None):
    """Plot a system in 2 or 3 dimensions.

    An alias exists for this common name: ``kwant.plot``.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.FiniteSystem
        A system to be plotted.
    num_lead_cells : int
        Number of lead copies to be shown with the system.
    unit : 'nn', 'pt', float, or `None`
        The unit used to specify symbol sizes and linewidths.
        Possible choices are:

        - 'nn': unit is the shortest hopping or a typical nearest neighbor
          distance in the system if there are no hoppings. This means that
          symbol sizes/linewidths will scale as the zoom level of the figure is
          changed. Very short distances are discarded before searching for the
          shortest. This choice means that the symbols will scale if the
          figure is zoomed.
        - 'pt': unit is points (point = 1/72 inch) in figure space. This means
          that symbols and linewidths will always be drawn with the same size
          independent of zoom level of the plot.
        - float: sizes are given in units of this value in real (system) space,
          and will accordingly scale as the plot is zoomed.

        If `None` (the default), the matplotlib engine uses 'nn', which
        makes it possible to ensure that the symbols of neighboring sites do
        not overlap. The plotly engine only supports 'pt', which is also its
        default.

    site_symbol : symbol specification, function, array, or `None`
        Symbol used for representing a site in the plot. Can be specified as

        - 'o': circle with radius of 1 unit.
        - 's': square with inner circle radius of 1 unit.
        - ``('p', nvert, angle)``: regular polygon with ``nvert`` vertices,
          rotated by ``angle``. ``angle`` is given in degrees, and ``angle=0``
          corresponds to one edge of the polygon pointing upward. The
          radius of the inner circle is 1 unit. [Unsupported by plotly engine]
        - 'no symbol': no symbol is plotted. [Unsupported by plotly engine]
        - 'S', `('P', nvert, angle)`: as the lower-case variants described
          above, but with an area equal to a circle of radius 1. (Makes
          the visual size of the symbol equal to the size of a circle with
          radius 1). [Unsupported by plotly engine]
        - matplotlib.path.Path instance. [Unsupported by plotly engine]

        Instead of a single symbol, different symbols can be specified
        for different sites by passing a function that returns a valid
        symbol specification for each site, or by passing an array of
        symbols specifications (only for kwant.system.FiniteSystem).
    site_size : number, function, array, or `None`
        Relative (linear) size of the site symbol.
        An array may not be used when `sys` is a kwant.Builder.
    site_color : ``matplotlib`` color description, function, array, or `None`
        A color used for plotting a site in the system. If a colormap is used,
        it should be a function returning single floats or a one-dimensional
        array of floats. By default sites are colored by their site family,
        using the current matplotlib color cycle.
        An array of colors may not be used when `sys` is a kwant.Builder.
    site_edgecolor : ``matplotlib`` color description, function, array, or `None`
        Color used for plotting the edges of the site symbols. Only
        valid matplotlib color descriptions are allowed (and no
        combination of floats and colormap as for site_color).
        An array of colors may not be used when `sys` is a kwant.Builder.
    site_lw : number, function, array, or `None`
        Linewidth of the site symbol edges.
        An array may not be used when `sys` is a kwant.Builder.
    hop_color : ``matplotlib`` color description or a function
        Same as `site_color`, but for hoppings. A function is passed two sites
        in this case. (arrays are not allowed in this case).
    hop_lw : number, function, or `None`
        Linewidth of the hoppings.
    lead_site_symbol : symbol specification or `None`
        Symbol to be used for the leads. See `site_symbol` for allowed
        specifications. Note that for leads, only constants
        (i.e. no functions or arrays) are allowed. If None, then
        `site_symbol` is used if it is constant (i.e. no function or array),
        the default otherwise. The same holds for the other lead properties
        below.
    lead_site_size : number or `None`
        Relative (linear) size of the lead symbol.
    lead_color : ``matplotlib`` color description or `None`
        For the leads, `num_lead_cells` copies of the lead unit cell
        are plotted. They are plotted in color fading from `lead_color`
        to white (alpha values in `lead_color` are supported) when moving
        from the system into the lead. Is also applied to the
        hoppings.
    lead_site_edgecolor : ``matplotlib`` color description or `None`
        Color of the symbol edges (no fading done).
    lead_site_lw : number or `None`
        Linewidth of the lead symbols.
    lead_hop_lw : number or `None`
        Linewidth of the lead hoppings.
    cmap : ``matplotlib`` color map or a sequence of two color maps or `None`
        The color map used for sites and optionally hoppings.
    pos_transform : function or `None`
        Transformation to be applied to the site position.
    colorbar : bool
        Whether to show a colorbar if colormap is used. Ignored if `ax` is
        provided.
    file : string or file object or `None`
        The output file. If `None`, output will be shown instead.
    show : bool
        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
        to be shown immediately. Defaults to `True`.
    dpi : float or `None`
        Number of pixels per inch. If not set the ``matplotlib`` default is
        used.
    fig_size : tuple or `None`
        Figure size `(width, height)` in inches. If not set, the default
        ``matplotlib`` value is used.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`. In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.

    Returns
    -------
    fig : matplotlib or plotly figure
        With the matplotlib engine: a figure with the output if `ax` is not
        set, else None. With the plotly engine: a plotly figure.

    Raises
    ------
    RuntimeError
        If no plotting engine is available or the active engine is not
        supported.

    Notes
    -----
    - If `None` is passed for a plot property, a default value depending on
      the dimension is chosen. Typically, the default values result in
      acceptable plots.

    - The meaning of "site" depends on whether the system to be plotted is a
      builder or a low level system. For builders, a site is a
      kwant.system.Site object. For low level systems, a site is an integer
      -- the site number.

    - color and symbol definitions may be tuples, but not lists or arrays.
      Arrays of values (linewidths, colors, sizes) may not be tuples.

    - The dimensionality of the plot (2D vs 3D) is inferred from the coordinate
      array. If there are more than three coordinates, only the first three
      are used. If there is just one coordinate, the second one is padded with
      zeros.

    - The system is scaled to fit the smaller dimension of the figure, given
      its aspect ratio.

    """

    # Dispatch to the backend of the active engine. Each backend chooses its
    # own default `unit` when `None` is passed.
    if _p.engine == "matplotlib":
        fig = _plot_matplotlib(sys, num_lead_cells, unit,
                               site_symbol, site_size,
                               site_color, site_edgecolor, site_lw,
                               hop_color, hop_lw,
                               lead_site_symbol, lead_site_size, lead_color,
                               lead_site_edgecolor, lead_site_lw,
                               lead_hop_lw, pos_transform,
                               cmap, colorbar, file,
                               show, dpi, fig_size, ax)
    elif _p.engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)
        fig = _plot_plotly(sys, num_lead_cells, unit,
                           site_symbol, site_size,
                           site_color, site_edgecolor, site_lw,
                           hop_color, hop_lw,
                           lead_site_symbol, lead_site_size, lead_color,
                           lead_site_edgecolor, lead_site_lw,
                           lead_hop_lw, pos_transform,
                           cmap, colorbar, file,
                           show)
    elif _p.engine is None:
        raise RuntimeError("Cannot use plot() without a plotting lib installed")
    else:
        raise RuntimeError("plot() does not support engine '{}'".format(_p.engine))

    _maybe_output_fig(fig, file=file, show=show)

    return fig

921 

def _resize_to_dim(array, dim):
    """Return a 2D array of positions with exactly `dim` columns.

    If `array` already has `dim` columns it is returned as is (no copy).
    Otherwise a new float array is built: surplus columns are dropped and
    missing columns are filled with zeros.
    """
    n_cols = array.shape[1]
    if n_cols == dim:
        return array
    n_keep = min(dim, n_cols)
    resized = np.zeros((len(array), dim), dtype=float)
    resized[:, :n_keep] = array[:, :n_keep]
    return resized

930 

931 

def _check_length(name, loc):
    """Validate the per-site plot property stored under `name` in `loc`.

    `loc` is a mapping (callers pass their ``locals()``) that must contain
    `name` and ``'n_syst_sites'``. Strings, tuples and values without a
    length (numbers, None, ...) are accepted as they are.

    Raises
    ------
    TypeError
        If `name` is 'site_size' or 'site_lw' and its value is a tuple.
    ValueError
        If the value has a length different from ``loc['n_syst_sites']``.
    """
    spec = loc[name]
    is_tuple = isinstance(spec, tuple)
    if is_tuple and name in ('site_size', 'site_lw'):
        raise TypeError('{0} may not be a tuple, use list or '
                        'array instead.'.format(name))
    if is_tuple or isinstance(spec, str):
        # Tuples and strings are single color/symbol specs, not per-site arrays.
        return
    try:
        mismatch = bool(len(spec) != loc['n_syst_sites'])
    except TypeError:
        # Unsized value: a constant or function spec, nothing to check.
        return
    if mismatch:
        raise ValueError('Length of {0} is not equal to number of '
                         'system sites.'.format(name))

945 

# make all specs proper: either constant or lists/np.arrays:
def _make_proper_site_spec(spec_name, spec, syst, sites, fancy_indexing=False):
    """Normalize a per-site plot property.

    Parameters
    ----------
    spec_name : str
        Name of the property; only used in the error message.
    spec : constant, callable, or array
        If callable, it is evaluated on ``i[0]`` for every entry ``i`` of
        `sites` whose ``i[1]`` is None (the system sites; lead sites are
        skipped), producing a list.
    syst : kwant.builder.Builder or finalized system
    sites : sequence
        Site tuples as returned by `sys_leads_sites`.
    fancy_indexing : bool
        If true, an array-like spec that is not yet a NumPy array is
        converted to one (falling back to object dtype when no regular array
        can be built), so it can be indexed with integer index arrays.

    Returns
    -------
    The constant `spec` unchanged, a list, or a NumPy array.

    Raises
    ------
    TypeError
        If `spec` is an array and `syst` is a Builder.
    """
    if _p.isarray(spec) and isinstance(syst, builder.Builder):
        raise TypeError('{} cannot be an array when plotting'
                        ' a Builder; use a function instead.'
                        .format(spec_name))
    if callable(spec):
        spec = [spec(i[0]) for i in sites if i[1] is None]
    if (fancy_indexing and _p.isarray(spec)
            and not isinstance(spec, np.ndarray)):
        try:
            spec = np.asarray(spec)
        except (TypeError, ValueError, Warning):
            # Ragged entries (e.g. a mix of color names and RGB tuples) cannot
            # form a regular array: NumPy raises ValueError, or a Warning when
            # older versions run with warnings escalated to errors. Unlike a
            # bare ``except``, this lets KeyboardInterrupt/SystemExit through.
            spec = np.asarray(spec, dtype='object')
    return spec

961 

def _make_proper_hop_spec(spec, hops, fancy_indexing=False):
    """Normalize a per-hopping plot property.

    Parameters
    ----------
    spec : constant, callable, or array
        If callable, it is called as ``spec(*i[0])`` for every entry ``i`` of
        `hops` whose ``i[1]`` is None (the system hoppings; lead hoppings are
        skipped), producing a list.
    hops : sequence
        Hopping tuples as returned by `sys_leads_hoppings`.
    fancy_indexing : bool
        If true, an array-like spec that is not yet a NumPy array is
        converted to one (falling back to object dtype when no regular array
        can be built).

    Returns
    -------
    The constant `spec` unchanged, a list, or a NumPy array.
    """
    if callable(spec):
        spec = [spec(*i[0]) for i in hops if i[1] is None]
    if (fancy_indexing and _p.isarray(spec)
            and not isinstance(spec, np.ndarray)):
        try:
            spec = np.asarray(spec)
        except (TypeError, ValueError, Warning):
            # Ragged entries cannot form a regular array: NumPy raises
            # ValueError, or a Warning when older versions run with warnings
            # escalated to errors. Unlike a bare ``except``, this lets
            # KeyboardInterrupt/SystemExit through.
            spec = np.asarray(spec, dtype='object')
    return spec

972 

def _plot_plotly(sys, num_lead_cells, unit,
                 site_symbol, site_size,
                 site_color, site_edgecolor, site_lw,
                 hop_color, hop_lw,
                 lead_site_symbol, lead_site_size, lead_color,
                 lead_site_edgecolor, lead_site_lw,
                 lead_hop_lw, pos_transform,
                 cmap, colorbar, file,
                 show, fig=None):
    """Plotly backend of `plot`.

    See `plot` for the meaning of the plot-property parameters. Only
    ``unit='pt'`` is supported (it is also used when `unit` is None), and
    site edge colors/linewidths, hopping colors/linewidths and lead edge
    colors/linewidths must be constants.

    If `fig` is None, a new plotly ``Figure`` is created containing the
    system hoppings, system sites, lead hoppings and lead sites (in this
    order). Otherwise only the lead traces are added to `fig`.
    `colorbar`, `file` and `show` are accepted for signature compatibility
    with the matplotlib backend but are not used here.

    Returns
    -------
    fig : plotly figure

    Raises
    ------
    RuntimeError
        If `unit` is not 'pt', or if one of the properties listed above is
        not a constant.
    """
    if unit is None:
        unit = 'pt'

    syst = sys  # for naming consistency inside function bodies
    # Generate data.
    sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)
    n_syst_sites = sum(i[1] is None for i in sites)
    sites_pos = sys_leads_pos(syst, sites)
    hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells)
    n_syst_hops = sum(i[1] is None for i in hops)
    end_pos, start_pos = sys_leads_hopping_pos(syst, hops)

    # _check_length looks the specs and `n_syst_sites` up by name.
    loc = locals()
    for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor',
                 'site_lw']:
        _check_length(name, loc)

    # Apply transformations to the data
    if pos_transform is not None:
        sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos)
        end_pos = np.apply_along_axis(pos_transform, 1, end_pos)
        start_pos = np.apply_along_axis(pos_transform, 1, start_pos)

    dim = 3 if (sites_pos.shape[1] == 3) else 2

    sites_pos = _resize_to_dim(sites_pos, dim)
    end_pos = _resize_to_dim(end_pos, dim)
    start_pos = _resize_to_dim(start_pos, dim)

    # Determine the reference length.
    if unit != 'pt':
        raise RuntimeError('Plotly engine currently only supports '
                           'the pt symbol size unit')

    def segments_trace(end, start):
        # Build one trace drawing a segment from end[k] to start[k] for every
        # row k. A None after each pair breaks the polyline, so the segments
        # are drawn disconnected.
        nones = [None] * len(end)
        coords = {axis: np.array([a, b, nones]).transpose().flatten()
                  for axis, a, b in zip('xyz', end.transpose(),
                                        start.transpose())}
        if dim == 3:
            return _p.plotly_graph_objs.Scatter3d(**coords)
        return _p.plotly_graph_objs.Scatter(**coords)

    site_symbol = _make_proper_site_spec('site_symbol', site_symbol, syst, sites)
    if site_symbol is None:
        site_symbol = defaults['site_symbol'][dim]
    # separate different symbols (not done in 3D, the separation
    # would mess up sorting)
    if (_p.isarray(site_symbol) and dim != 3 and
            (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))):
        symbol_dict = defaultdict(list)
        for i, symbol in enumerate(site_symbol):
            symbol_dict[symbol].append(i)
        symbol_slcs = [(symbol, np.array(indx))
                       for symbol, indx in symbol_dict.items()]
        fancy_indexing = True
    else:
        symbol_slcs = [(site_symbol, slice(n_syst_sites))]
        fancy_indexing = False

    if site_color is None:
        cycle = _color_cycle()
        # Same system detection as the matplotlib backend.
        if builder.is_system(syst):
            # Skipping the leads for brevity.
            families = sorted({site.family for site in syst.sites})
            color_mapping = dict(zip(families, cycle))
            def site_color(site):
                return color_mapping[syst.sites[site].family]
        elif isinstance(syst, builder.Builder):
            families = sorted({site[0].family for site in sites})
            color_mapping = dict(zip(families, cycle))
            def site_color(site):
                return color_mapping[site.family]
        else:
            # Unknown finalized system, no sites access.
            site_color = defaults['site_color'][dim]

    site_size = _make_proper_site_spec('site_size', site_size, syst, sites,
                                       fancy_indexing)
    site_color = _make_proper_site_spec('site_color', site_color, syst, sites,
                                        fancy_indexing)
    site_edgecolor = _make_proper_site_spec('site_edgecolor', site_edgecolor,
                                            syst, sites, fancy_indexing)
    site_lw = _make_proper_site_spec('site_lw', site_lw, syst, sites,
                                     fancy_indexing)

    hop_color = _make_proper_hop_spec(hop_color, hops)
    hop_lw = _make_proper_hop_spec(hop_lw, hops)

    # Choose defaults depending on dimension, if None was given
    if site_size is None:
        site_size = defaults['site_size'][dim]
    if site_edgecolor is None:
        site_edgecolor = defaults['site_edgecolor'][dim]
    if site_lw is None:
        site_lw = defaults['site_lw'][dim]

    if hop_color is None:
        hop_color = defaults['hop_color'][dim]
    if hop_lw is None:
        hop_lw = defaults['hop_lw'][dim]

    if len(symbol_slcs) > 1:
        # Like the matplotlib backend, cast numeric per-site colors to float.
        try:
            if site_color.ndim == 1 and len(site_color) == n_syst_sites:
                site_color = np.asarray(site_color, dtype=float)
        except (AttributeError, TypeError, ValueError):
            # Not a numeric NumPy array (e.g. a color name or a list of
            # color names): keep the spec as it is.
            pass

    # take spec also for lead, if it's not a list/array, default, otherwise
    if lead_site_symbol is None:
        lead_site_symbol = (site_symbol if not _p.isarray(site_symbol)
                            else defaults['site_symbol'][dim])
    if lead_site_size is None:
        lead_site_size = (site_size if not _p.isarray(site_size)
                          else defaults['site_size'][dim])
    if lead_color is None:
        lead_color = defaults['lead_color'][dim]
    # NOTE(review): this relies on matplotlib's color converter even in the
    # plotly backend.
    lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color)

    if lead_site_edgecolor is None:
        lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor)
                               else defaults['site_edgecolor'][dim])
    if lead_site_lw is None:
        lead_site_lw = (site_lw if not _p.isarray(site_lw)
                        else defaults['site_lw'][dim])
    if lead_hop_lw is None:
        lead_hop_lw = (hop_lw if not _p.isarray(hop_lw)
                       else defaults['hop_lw'][dim])

    if not isinstance(cmap, str):
        try:
            # A pair means (site colormap, hopping colormap); this backend
            # only uses the site colormap.
            cmap, _hop_cmap = cmap
        except TypeError:
            pass

    # Plot system sites and hoppings

    # First plot the nodes (sites) of the graph
    assert dim == 2 or dim == 3
    site_node_trace, site_edge_trace = [], []
    for symbol, slc in symbol_slcs:
        site_symbol_plotly = _p.convert_symbol_mpl_plotly(symbol)
        if site_symbol_plotly == -1:
            # 'no symbol' (a documented site_symbol value) converts to -1:
            # no markers are drawn for these sites.
            # https://kwant-project.org/doc/1/reference/generated/kwant.plotter.plot
            continue
        size = site_size[slc] if _p.isarray(site_size) else site_size
        col = site_color[slc] if _p.isarray(site_color) else site_color
        if _p.isarray(site_edgecolor) or _p.isarray(site_lw):
            raise RuntimeError("Plotly engine not currently support an array "
                               "of linecolors or linewidths. Please restrict "
                               "to only a constant (i.e. no function or array)"
                               " site_edgecolor and site_lw property "
                               "for the entire plot.")
        edgecol = (_p.convert_colormap_mpl_plotly(*site_edgecolor)
                   if isinstance(site_edgecolor, tuple) else site_edgecolor)
        lw = site_lw

        if dim == 3:
            x, y, z = sites_pos[slc].transpose()
            site_node_trace_elem = _p.plotly_graph_objs.Scatter3d(x=x, y=y,
                                                                  z=z)
            site_node_trace_elem.marker.symbol = \
                _p.convert_symbol_mpl_plotly_3d(symbol)
        else:
            x, y = sites_pos[slc].transpose()
            site_node_trace_elem = _p.plotly_graph_objs.Scatter(x=x, y=y)
            site_node_trace_elem.marker.symbol = site_symbol_plotly

        site_node_trace_elem.mode = 'markers'
        site_node_trace_elem.hoverinfo = 'none'
        site_node_trace_elem.marker.showscale = False
        site_node_trace_elem.marker.colorscale = \
            _p.convert_cmap_list_mpl_plotly(cmap)
        site_node_trace_elem.marker.reversescale = False
        site_node_trace_elem.marker.color = (
            _p.convert_colormap_mpl_plotly(*col) if isinstance(col, tuple)
            else col)
        site_node_trace_elem.marker.size = \
            _p.convert_site_size_mpl_plotly(
                size, defaults['plotly_site_size_reference'])

        site_node_trace_elem.line.width = lw
        site_node_trace_elem.line.color = edgecol
        site_node_trace_elem.showlegend = False

        site_node_trace.append(site_node_trace_elem)

    # Now plot the edges (hops) of the graph
    site_edge_trace_elem = segments_trace(end_pos[:n_syst_hops],
                                          start_pos[:n_syst_hops])
    if _p.isarray(hop_color) or _p.isarray(hop_lw):
        raise RuntimeError("Plotly engine not currently support an array "
                           "of linecolors or linewidths. Please restrict "
                           "to only a constant (i.e. no function or array)"
                           " hop_color and hop_lw property "
                           "for the entire plot.")
    site_edge_trace_elem.line.width = hop_lw
    site_edge_trace_elem.line.color = hop_color
    site_edge_trace_elem.hoverinfo = 'none'
    site_edge_trace_elem.showlegend = False
    site_edge_trace_elem.mode = 'lines'
    site_edge_trace.append(site_edge_trace_elem)

    # Plot lead sites and edges

    lead_node_trace, lead_edge_trace = [], []
    for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs):
        lead_site_symbol_plotly = _p.convert_symbol_mpl_plotly(lead_site_symbol)
        # 'no symbol' converts to -1: skip the lead markers, but still draw
        # the lead hoppings below (previously a `continue` dropped them too).
        if lead_site_symbol_plotly != -1:
            lead_site_colors = np.array([i[2] for i in sites[sites_slc]],
                                        dtype=float)
            if dim == 3:
                x, y, z = sites_pos[sites_slc].transpose()
                lead_node_trace_elem = _p.plotly_graph_objs.Scatter3d(
                    x=x, y=y, z=z)
                lead_node_trace_elem.marker.symbol = \
                    _p.convert_symbol_mpl_plotly_3d(lead_site_symbol)
            else:
                x, y = sites_pos[sites_slc].transpose()
                lead_node_trace_elem = _p.plotly_graph_objs.Scatter(x=x, y=y)
                lead_node_trace_elem.marker.symbol = lead_site_symbol_plotly

            lead_node_trace_elem.mode = 'markers'
            lead_node_trace_elem.hoverinfo = 'none'
            lead_node_trace_elem.showlegend = False
            lead_node_trace_elem.marker.showscale = False
            lead_node_trace_elem.marker.reversescale = False
            lead_node_trace_elem.marker.color = lead_site_colors
            # Colors fade from lead_color to white with the same alpha.
            lead_node_trace_elem.marker.colorscale = \
                _p.convert_lead_cmap_mpl_plotly(lead_color,
                                                [1, 1, 1, lead_color[3]])
            lead_node_trace_elem.marker.size = _p.convert_site_size_mpl_plotly(
                lead_site_size,
                defaults['plotly_site_size_reference'])

            if _p.isarray(lead_site_lw) or _p.isarray(lead_site_edgecolor):
                raise RuntimeError("Plotly engine not currently support an array "
                                   "of linecolors or linewidths. Please restrict "
                                   "to only a constant (i.e. no function or array) "
                                   "lead_site_lw and lead_site_edgecolor property "
                                   "for the entire plot.")
            lead_node_trace_elem.line.width = lead_site_lw
            lead_node_trace_elem.line.color = lead_site_edgecolor

            lead_node_trace.append(lead_node_trace_elem)

        lead_edge_trace_elem = segments_trace(end_pos[hops_slc],
                                              start_pos[hops_slc])
        lead_edge_trace_elem.line.width = lead_hop_lw
        lead_edge_trace_elem.line.color = \
            _p.convert_colormap_mpl_plotly(*lead_color)
        lead_edge_trace_elem.hoverinfo = 'none'
        lead_edge_trace_elem.mode = 'lines'
        lead_edge_trace_elem.showlegend = False

        lead_edge_trace.append(lead_edge_trace_elem)

    layout = _p.plotly_graph_objs.Layout(
        showlegend=False,
        hovermode='closest',
        xaxis=dict(showgrid=False, zeroline=False,
                   showticklabels=True),
        yaxis=dict(showgrid=False, zeroline=False,
                   showticklabels=True))
    if fig is None:
        full_trace = list(itertools.chain.from_iterable([site_edge_trace,
                                                         site_node_trace,
                                                         lead_edge_trace,
                                                         lead_node_trace]))
        fig = _p.plotly_graph_objs.Figure(data=full_trace,
                                          layout=layout)
    else:
        full_trace = list(itertools.chain.from_iterable([lead_edge_trace,
                                                         lead_node_trace]))
        for trace in full_trace:
            try:
                fig.add_trace(trace)
            except TypeError:
                # Fallback for figure objects whose add_trace rejects the trace.
                fig.data += [trace]

    return fig

1297 

1298 

1299def _plot_matplotlib(sys, num_lead_cells, unit, 

1300 site_symbol, site_size, 

1301 site_color, site_edgecolor, site_lw, 

1302 hop_color, hop_lw, 

1303 lead_site_symbol, lead_site_size, lead_color, 

1304 lead_site_edgecolor, lead_site_lw, 

1305 lead_hop_lw, pos_transform, 

1306 cmap, colorbar, file, 

1307 show, dpi, fig_size, ax): 

1308 

1309 if unit is None: 1309 ↛ 1312line 1309 didn't jump to line 1312, because the condition on line 1309 was never false

1310 unit = 'nn' 

1311 

1312 syst = sys # for naming consistency inside function bodies 

1313 # Generate data. 

1314 sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells) 

1315 n_syst_sites = sum(i[1] is None for i in sites) 

1316 sites_pos = sys_leads_pos(syst, sites) 

1317 hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells) 

1318 n_syst_hops = sum(i[1] is None for i in hops) 

1319 end_pos, start_pos = sys_leads_hopping_pos(syst, hops) 

1320 

1321 loc = locals() 

1322 

1323 for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor', 

1324 'site_lw']: 

1325 _check_length(name, loc) 

1326 

1327 # Apply transformations to the data 

1328 if pos_transform is not None: 

1329 sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos) 

1330 end_pos = np.apply_along_axis(pos_transform, 1, end_pos) 

1331 start_pos = np.apply_along_axis(pos_transform, 1, start_pos) 

1332 

1333 dim = 3 if (sites_pos.shape[1] == 3) else 2 

1334 if dim == 3 and not _p.has3d: 1334 ↛ 1335line 1334 didn't jump to line 1335, because the condition on line 1334 was never true

1335 raise RuntimeError("Installed matplotlib does not support 3d plotting") 

1336 sites_pos = _resize_to_dim(sites_pos, dim) 

1337 end_pos = _resize_to_dim(end_pos, dim) 

1338 start_pos = _resize_to_dim(start_pos, dim) 

1339 

1340 # Determine the reference length. 

1341 if unit == 'pt': 1341 ↛ 1342line 1341 didn't jump to line 1342, because the condition on line 1341 was never true

1342 reflen = None 

1343 elif unit == 'nn': 1343 ↛ 1366line 1343 didn't jump to line 1366, because the condition on line 1343 was never false

1344 if n_syst_hops: 

1345 # If hoppings are present use their lengths to determine the 

1346 # minimal one. 

1347 distances = end_pos - start_pos 

1348 else: 

1349 # If no hoppings are present, use for the same purpose distances 

1350 # from ten randomly selected points to the remaining points in the 

1351 # system. 

1352 points = _sample_array(sites_pos, 10).T 

1353 distances = (sites_pos.reshape(1, -1, dim) - 

1354 points.reshape(-1, 1, dim)).reshape(-1, dim) 

1355 distances = np.sort(np.sum(distances**2, axis=1)) 

1356 # Then check if distances are present that are way shorter than the 

1357 # longest one. Then take first distance longer than these short 

1358 # ones. This heuristic will fail for too large systems, or systems with 

1359 # hoppings that vary by orders and orders of magnitude, but for sane 

1360 # cases it will work. 

1361 long_dist_coord = np.searchsorted(distances, 1e-16 * distances[-1]) 

1362 reflen = sqrt(distances[long_dist_coord]) 

1363 

1364 else: 

1365 # The last allowed value is float-compatible. 

1366 try: 

1367 reflen = float(unit) 

1368 except: 

1369 raise ValueError('Invalid value of unit argument.') 

1370 

1371 site_symbol = _make_proper_site_spec('site_symbol', site_symbol, syst, sites) 

1372 if site_symbol is None: site_symbol = defaults['site_symbol'][dim] 

1373 # separate different symbols (not done in 3D, the separation 

1374 # would mess up sorting) 

1375 if (_p.isarray(site_symbol) and dim != 3 and 1375 ↛ 1377line 1375 didn't jump to line 1377, because the condition on line 1375 was never true

1376 (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))): 

1377 symbol_dict = defaultdict(list) 

1378 for i, symbol in enumerate(site_symbol): 

1379 symbol_dict[symbol].append(i) 

1380 symbol_slcs = [] 

1381 for symbol, indx in symbol_dict.items(): 

1382 symbol_slcs.append((symbol, np.array(indx))) 

1383 fancy_indexing = True 

1384 else: 

1385 symbol_slcs = [(site_symbol, slice(n_syst_sites))] 

1386 fancy_indexing = False 

1387 

1388 if site_color is None: 

1389 cycle = _color_cycle() 

1390 if builder.is_system(syst): 

1391 # Skipping the leads for brevity. 

1392 families = sorted({site.family for site in syst.sites}) 

1393 color_mapping = dict(zip(families, cycle)) 

1394 def site_color(site): 

1395 return color_mapping[syst.sites[site].family] 

1396 elif isinstance(syst, builder.Builder): 1396 ↛ 1403line 1396 didn't jump to line 1403, because the condition on line 1396 was never false

1397 families = sorted({site[0].family for site in sites}) 

1398 color_mapping = dict(zip(families, cycle)) 

1399 def site_color(site): 

1400 return color_mapping[site.family] 

1401 else: 

1402 # Unknown finalized system, no sites access. 

1403 site_color = defaults['site_color'][dim] 

1404 

1405 site_size = _make_proper_site_spec('site_size', site_size, syst, sites, fancy_indexing) 

1406 site_color = _make_proper_site_spec('site_color', site_color, syst, sites, fancy_indexing) 

1407 site_edgecolor = _make_proper_site_spec('site_edgecolor', site_edgecolor, syst, sites, fancy_indexing) 

1408 site_lw = _make_proper_site_spec('site_lw', site_lw, syst, sites, fancy_indexing) 

1409 

1410 hop_color = _make_proper_hop_spec(hop_color, hops) 

1411 hop_lw = _make_proper_hop_spec(hop_lw, hops) 

1412 

1413 # Choose defaults depending on dimension, if None was given 

1414 if site_size is None: site_size = defaults['site_size'][dim] 

1415 if site_edgecolor is None: 1415 ↛ 1417line 1415 didn't jump to line 1417, because the condition on line 1415 was never false

1416 site_edgecolor = defaults['site_edgecolor'][dim] 

1417 if site_lw is None: site_lw = defaults['site_lw'][dim] 

1418 

1419 if hop_color is None: hop_color = defaults['hop_color'][dim] 

1420 if hop_lw is None: hop_lw = defaults['hop_lw'][dim] 

1421 

1422 # if symbols are split up into different collections, 

1423 # the colormapping will fail without normalization 

1424 norm = None 

1425 if len(symbol_slcs) > 1: 1425 ↛ 1426line 1425 didn't jump to line 1426, because the condition on line 1425 was never true

1426 try: 

1427 if site_color.ndim == 1 and len(site_color) == n_syst_sites: 

1428 site_color = np.asarray(site_color, dtype=float) 

1429 norm = _p.matplotlib.colors.Normalize(site_color.min(), 

1430 site_color.max()) 

1431 except: 

1432 pass 

1433 

1434 # take spec also for lead, if it's not a list/array, default, otherwise 

1435 if lead_site_symbol is None: 1435 ↛ 1438line 1435 didn't jump to line 1438, because the condition on line 1435 was never false

1436 lead_site_symbol = (site_symbol if not _p.isarray(site_symbol) 

1437 else defaults['site_symbol'][dim]) 

1438 if lead_site_size is None: 1438 ↛ 1441line 1438 didn't jump to line 1441, because the condition on line 1438 was never false

1439 lead_site_size = (site_size if not _p.isarray(site_size) 

1440 else defaults['site_size'][dim]) 

1441 if lead_color is None: 1441 ↛ 1443line 1441 didn't jump to line 1443, because the condition on line 1441 was never false

1442 lead_color = defaults['lead_color'][dim] 

1443 lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color) 

1444 

1445 if lead_site_edgecolor is None: 1445 ↛ 1448line 1445 didn't jump to line 1448, because the condition on line 1445 was never false

1446 lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor) 

1447 else defaults['site_edgecolor'][dim]) 

1448 if lead_site_lw is None: 1448 ↛ 1451line 1448 didn't jump to line 1451, because the condition on line 1448 was never false

1449 lead_site_lw = (site_lw if not _p.isarray(site_lw) 

1450 else defaults['site_lw'][dim]) 

1451 if lead_hop_lw is None: 1451 ↛ 1455line 1451 didn't jump to line 1455, because the condition on line 1451 was never false

1452 lead_hop_lw = (hop_lw if not _p.isarray(hop_lw) 

1453 else defaults['hop_lw'][dim]) 

1454 

1455 hop_cmap = None 

1456 if not isinstance(cmap, str): 1456 ↛ 1457line 1456 didn't jump to line 1457, because the condition on line 1456 was never true

1457 try: 

1458 cmap, hop_cmap = cmap 

1459 except TypeError: 

1460 pass 

1461 

1462 # make a new figure unless axes specified 

1463 if not ax: 1463 ↛ 1474line 1463 didn't jump to line 1474, because the condition on line 1463 was never false

1464 fig = _make_figure(dpi, fig_size, use_pyplot=(file is None)) 

1465 if dim == 2: 

1466 ax = fig.add_subplot(1, 1, 1, aspect='equal') 

1467 ax.set_xmargin(0.05) 

1468 ax.set_ymargin(0.05) 

1469 else: 

1470 warnings.filterwarnings('ignore', message=r'.*rotation.*') 

1471 ax = fig.add_subplot(1, 1, 1, projection='3d') 

1472 warnings.resetwarnings() 

1473 else: 

1474 fig = None 

1475 

1476 # plot system sites and hoppings 

1477 for symbol, slc in symbol_slcs: 

1478 size = site_size[slc] if _p.isarray(site_size) else site_size 

1479 col = site_color[slc] if _p.isarray(site_color) else site_color 

1480 edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else 

1481 site_edgecolor) 

1482 lw = site_lw[slc] if _p.isarray(site_lw) else site_lw 

1483 

1484 symbol_coll = symbols(ax, sites_pos[slc], size=size, 

1485 reflen=reflen, symbol=symbol, 

1486 facecolor=col, edgecolor=edgecol, 

1487 linewidth=lw, cmap=cmap, norm=norm, zorder=2) 

1488 

1489 end, start = end_pos[: n_syst_hops], start_pos[: n_syst_hops] 

1490 line_coll = lines(ax, end, start, reflen, hop_color, linewidths=hop_lw, 

1491 zorder=1, cmap=hop_cmap) 

1492 

1493 # plot lead sites and hoppings 

1494 norm = _p.matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5) 

1495 cmap_from_list = _p.matplotlib.colors.LinearSegmentedColormap.from_list 

1496 lead_cmap = cmap_from_list(None, [lead_color, (1, 1, 1, lead_color[3])]) 

1497 

1498 for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs): 

1499 lead_site_colors = np.array([i[2] for i in sites[sites_slc]], 

1500 dtype=float) 

1501 

1502 # Note: the previous version of the code had in addition this 

1503 # line in the 3D case: 

1504 # lead_site_colors = 1 / np.sqrt(1. + lead_site_colors) 

1505 symbols(ax, sites_pos[sites_slc], size=lead_site_size, reflen=reflen, 

1506 symbol=lead_site_symbol, facecolor=lead_site_colors, 

1507 edgecolor=lead_site_edgecolor, linewidth=lead_site_lw, 

1508 cmap=lead_cmap, zorder=2, norm=norm) 

1509 

1510 lead_hop_colors = np.array([i[2] for i in hops[hops_slc]], dtype=float) 

1511 

1512 # Note: the previous version of the code had in addition this 

1513 # line in the 3D case: 

1514 # lead_hop_colors = 1 / np.sqrt(1. + lead_hop_colors) 

1515 end, start = end_pos[hops_slc], start_pos[hops_slc] 

1516 lines(ax, end, start, reflen, lead_hop_colors, linewidths=lead_hop_lw, 

1517 cmap=lead_cmap, norm=norm, zorder=1) 

1518 

1519 min_ = np.min(sites_pos, 0) 

1520 max_ = np.max(sites_pos, 0) 

1521 m = (min_ + max_) / 2 

1522 if dim == 2: 

1523 w = np.max([(max_ - min_) / 2, (reflen, reflen)], axis=0) 

1524 ax.update_datalim((m - w, m + w)) 

1525 ax.autoscale_view(tight=True) 

1526 else: 

1527 # make axis limits the same in all directions 

1528 # (3D only works decently for equal aspect ratio. Since 

1529 # this doesn't work out of the box in mplot3d, this is a 

1530 # workaround) 

1531 w = np.max(max_ - min_) / 2 

1532 ax.auto_scale_xyz(*[(i - w, i + w) for i in m], had_data=True) 

1533 

1534 # add separate colorbars for symbols and hoppings if ncessary 

1535 if symbol_coll.get_array() is not None and colorbar and fig is not None: 

1536 fig.colorbar(symbol_coll) 

1537 if line_coll.get_array() is not None and colorbar and fig is not None: 

1538 fig.colorbar(line_coll) 

1539 

1540 return fig 

1541 

1542 

def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
    """Interpolate a scalar function in vicinity of given points.

    Create a masked array corresponding to interpolated values of the function
    at points lying not further than a certain distance from the original
    data points provided.

    Parameters
    ----------
    coords : np.ndarray
        An array with site coordinates.
    values : np.ndarray
        An array with the values from which the interpolation should be built.
    a : float, optional
        Reference length.  If not given, it is determined as a typical
        nearest neighbor distance.
    method : string, optional
        Passed to ``scipy.interpolate.griddata``: "nearest" (default), "linear",
        or "cubic"
    oversampling : integer, optional
        Number of pixels per reference length.  Defaults to 3.

    Returns
    -------
    array : 2d NumPy array
        The interpolated values, with pixels further than about ``a`` from
        every site masked out.  When the active plotting engine is not
        matplotlib, the masked pixels are filled with NaN instead and a plain
        array is returned; otherwise a masked array is returned.
    img : 2d NumPy array
        The interpolated values without any masking applied.
    min, max : vectors
        The real-space coordinates of the two extreme ([0, 0] and [-1, -1])
        points of ``array``.

    Raises
    ------
    ValueError
        If the number of coordinates and values differ, or if the reference
        length ``a`` is vanishingly small compared to the system size.

    Notes
    -----
    - `min` and `max` are chosen such that when plotting a system on a square
      lattice and `oversampling` is set to an odd integer, each site will lie
      exactly at the center of a pixel of the output array.

    - When plotting a system on a square lattice and `method` is "nearest", it
      makes sense to set `oversampling` to ``1``.  Then, each site will
      correspond to exactly one pixel in the resulting array.
    """
    # Validate cheaply before building the KD-tree.
    if len(coords) != len(values):
        raise ValueError("The number of sites doesn't match the number of "
                         "provided values.")

    # Build the bounding box.
    cmin, cmax = coords.min(0), coords.max(0)

    tree = spatial.cKDTree(coords)

    # Select 10 sites to compare -- comparing them all is too costly.
    points = _sample_array(coords, 10)
    # Query the two nearest neighbours and keep column 1.  Assuming
    # ``_sample_array`` returns a subset of ``coords``, column 0 is each
    # point's zero distance to itself.
    min_dist = np.min(tree.query(points, 2)[0][:, 1])
    if min_dist < 1e-6 * np.linalg.norm(cmax - cmin):
        warnings.warn("Some sites have nearly coinciding positions, "
                      "interpolation may be confusing.",
                      RuntimeWarning, stacklevel=2)

    if a is None:
        a = min_dist

    if a < 1e-6 * np.linalg.norm(cmax - cmin):
        raise ValueError("The reference distance a is too small.")

    shape = (((cmax - cmin) / a + 1) * oversampling).round()
    # Widen the box so that, for odd oversampling, sites sit at pixel centres.
    delta = 0.5 * (oversampling - 1) * a / oversampling
    cmin -= delta
    cmax += delta
    # A complex step in a slice means "this many points, endpoints included".
    dims = tuple(slice(cmin[i], cmax[i], 1j * shape[i])
                 for i in range(len(cmin)))
    grid = tuple(np.ogrid[dims])
    img = interpolate.griddata(coords, values, grid, method)
    # ``np.float_`` was removed in NumPy 2.0; ``np.float64`` is the same type.
    img = img.astype(np.float64)
    mask = np.mgrid[dims].reshape(len(cmin), -1).T
    # The numerical values in the following line are optimized for the common
    # case of a square lattice:
    # * 0.99 makes sure that non-masked pixels and sites correspond 1-by-1 to
    #   each other when oversampling == 1.
    # * 0.4 (which is just below sqrt(2) - 1) makes tree.query() exact.
    mask = tree.query(mask, eps=0.4)[0] > 0.99 * a

    masked_result_array = np.ma.masked_array(img, mask)

    try:
        if _p.engine != "matplotlib":
            # Outside matplotlib, represent masked pixels as NaN rather than
            # through a mask.  (``np.NaN`` was removed in NumPy 2.0.)
            result_array = masked_result_array.filled(np.nan)
        else:
            result_array = masked_result_array
    except AttributeError:
        result_array = masked_result_array

    return result_array, img, cmin, cmax

1634 

1635 

def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
        method='nearest', oversampling=3, num_lead_cells=0, file=None,
        show=True, dpi=None, fig_size=None, ax=None, pos_transform=None,
        background='#e0e0e0'):
    """Show interpolated map of a function defined for the sites of a system.

    Create a pixmap representation of a function of the sites of a system by
    calling `~kwant.plotter.mask_interpolate` and show this pixmap using the
    active plotting engine (matplotlib or plotly).

    This function is similar to `~kwant.plotter.density`, but is more suited
    to the case where you want site-level resolution of the quantity that
    you are plotting.  If your system has many sites you may get more
    appealing plots by using `~kwant.plotter.density`.

    Parameters
    ----------
    sys : kwant.system.FiniteSystem or kwant.builder.Builder
        The system for whose sites `value` is to be plotted.
    value : function or list
        Function which takes a site and returns a value if the system is a
        builder, or a list of function values for each system site of the
        finalized system.
    colorbar : bool, optional
        Whether to show a color bar if numerical data has to be plotted.
        Defaults to `True`.  If `ax` is provided, the colorbar is never
        plotted.
    cmap : ``matplotlib`` color map or `None`
        The color map used for the sites.  If `None`, Kwant's default red
        color map is used.
    vmin : float, optional
        The lower saturation limit for the colormap; values returned by
        `value` which are smaller than this will saturate.
    vmax : float, optional
        The upper saturation limit for the colormap; values returned by
        `value` which are larger than this will saturate.
    a : float, optional
        Reference length.  If not given, it is determined as a typical
        nearest neighbor distance.
    method : string, optional
        Passed to ``scipy.interpolate.griddata``: "nearest" (default), "linear",
        or "cubic"
    oversampling : integer, optional
        Number of pixels per reference length.  Defaults to 3.
    num_lead_cells : integer, optional
        Number of lead unit cells that should be plotted to indicate
        the position of leads.  Defaults to 0.
    file : string or file object or `None`
        The output file.  If `None`, output will be shown instead.
    show : bool
        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
        to be shown immediately.  Defaults to `True`.
    dpi : float or `None`
        Number of pixels per inch.  If not set the ``matplotlib`` default is
        used.  Only used by the matplotlib engine.
    fig_size : tuple or `None`
        Figure size `(width, height)` in inches.  If not set, the default
        ``matplotlib`` value is used.  Only used by the matplotlib engine.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`.  In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.  Only used by the matplotlib engine.
    pos_transform : function or `None`
        Transformation to be applied to the site position.
    background : matplotlib color spec
        Areas without sites are filled with this color.

    Returns
    -------
    fig : matplotlib figure or plotly Figure
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    RuntimeError
        If no plotting library is available, or the engine is unsupported.
    ValueError
        If the system is not two-dimensional, or a list of values is given
        for a system that is not finalized.

    Notes
    -----
    - When plotting a system on a square lattice and `method` is "nearest", it
      makes sense to set `oversampling` to ``1``.  Then, each site will
      correspond to exactly one pixel.

    See Also
    --------
    kwant.plotter.density
    """

    if not (_p.mpl_available or _p.plotly_available):
        # Either engine is sufficient, so name both in the error.
        raise RuntimeError("Neither matplotlib nor plotly was found, but one "
                           "of them is required for map()")

    syst = sys  # for naming consistency inside function bodies
    sites = sys_leads_sites(syst, 0)[0]
    coords = sys_leads_pos(syst, sites)

    if pos_transform is not None:
        coords = np.apply_along_axis(pos_transform, 1, coords)

    if coords.shape[1] != 2:
        raise ValueError('Only 2D systems can be plotted this way.')

    if callable(value):
        value = [value(site[0]) for site in sites]
    else:
        if not system.is_finite(syst):
            raise ValueError('List of values is only allowed as input '
                             'for finalized systems.')
    value = np.array(value)
    with _common.reraise_warnings():
        # The unmasked image is not needed here; the bounds below are
        # computed from the displayed (masked) pixels only.
        img, _, _min, _max = mask_interpolate(coords, value,
                                              a, method, oversampling)

    # Calculate the min/max bounds for the colormap.
    # User-provided values take precedence.
    if _p.engine != "matplotlib":
        # mask_interpolate filled the masked pixels with NaN.
        unmasked_data = img.ravel()
    else:
        unmasked_data = img[~img.mask].data.flatten()
    unmasked_data = unmasked_data[~np.isnan(unmasked_data)]
    new_vmin, new_vmax = percentile_bound(unmasked_data, vmin, vmax)
    overflow_pct = 100 * np.sum(unmasked_data > new_vmax) / len(unmasked_data)
    underflow_pct = 100 * np.sum(unmasked_data < new_vmin) / len(unmasked_data)
    if (vmin is None and underflow_pct) or (vmax is None and overflow_pct):
        msg = (
            'The plotted data contains ',
            '{:.2f}% of values overflowing upper limit {:g} '
                .format(overflow_pct, new_vmax)
                if overflow_pct > 0 else '',
            'and ' if overflow_pct > 0 and underflow_pct > 0 else '',
            '{:.2f}% of values underflowing lower limit {:g} '
                .format(underflow_pct, new_vmin)
                if underflow_pct > 0 else '',
        )
        warnings.warn(''.join(msg), RuntimeWarning, stacklevel=2)
    vmin, vmax = new_vmin, new_vmax

    if _p.engine == "matplotlib":
        fig = _map_matplotlib(syst, img, colorbar, _max, _min, vmin, vmax,
                              overflow_pct, underflow_pct, cmap,
                              num_lead_cells, background, dpi, fig_size, ax,
                              file)
    elif _p.engine == "plotly":
        fig = _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax,
                          overflow_pct, underflow_pct, cmap, num_lead_cells,
                          background)
    elif _p.engine is None:
        raise RuntimeError("Cannot use map() without a plotting lib installed")
    else:
        raise RuntimeError("map() does not support engine '{}'".format(_p.engine))

    _maybe_output_fig(fig, file=file, show=show)

    return fig

1777 

1778 

def _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax, overflow_pct,
                underflow_pct, cmap, num_lead_cells, background):
    """Render the interpolated map ``img`` as a plotly heatmap.

    ``img`` is indexed as ``img[ix, iy]`` (axis ``i`` pairs with coordinate
    ``i`` of ``_min``/``_max``).  ``_min`` and ``_max`` are the coordinates
    of its first and last pixels.  The caller's ``_min``/``_max`` arrays are
    not modified.
    """
    if cmap is None:
        cmap = _p.kwant_red_plotly

    # Plotly places heatmap cells centred on the given x/y values.  So pass
    # the actual pixel coordinates; the drawn area then extends half a pixel
    # beyond them on every side.  This matches the half-pixel border that
    # the matplotlib backend adds to the imshow extent.
    n_x, n_y = img.shape
    contour_object = _p.plotly_graph_objs.Heatmap()
    # Plotly's ``z`` has y along rows, hence the transpose.  ``x`` must have
    # as many entries as there are x-pixels, i.e. ``img.shape[0]`` *before*
    # transposing.
    contour_object.z = img.T
    contour_object.x = np.linspace(_min[0], _max[0], n_x)
    contour_object.y = np.linspace(_min[1], _max[1], n_y)
    contour_object.zsmooth = False
    contour_object.connectgaps = False
    contour_object.colorscale = _p.convert_cmap_list_mpl_plotly(cmap)
    contour_object.zmax = vmax
    contour_object.zmin = vmin
    contour_object.hoverinfo = 'none'

    contour_object.showscale = colorbar

    fig = _p.plotly_graph_objs.Figure(data=[contour_object])
    fig.layout.plot_bgcolor = background
    fig.layout.showlegend = False

    if num_lead_cells:
        # Overlay only the lead sites (black squares) on top of the map.
        fig = _plot_plotly(syst, num_lead_cells, site_symbol='no symbol',
                           hop_lw=0, lead_site_symbol='s',
                           lead_site_size=0.501, lead_site_lw=0,
                           lead_hop_lw=0, lead_color='black', colorbar=False,
                           show=False, fig=fig, unit='pt', site_size=None,
                           site_color=None, site_edgecolor=None, site_lw=0,
                           hop_color=None, lead_site_edgecolor=None,
                           pos_transform=None, cmap=None, file=None)

    return fig

1819 

1820 

def _map_matplotlib(syst, img, colorbar, _max, _min, vmin, vmax,
                    overflow_pct, underflow_pct, cmap, num_lead_cells,
                    background, dpi, fig_size, ax, file):
    """Render the interpolated map ``img`` with matplotlib's ``imshow``.

    Returns the newly created figure, or None when drawing into the
    caller-supplied ``ax`` (in which case no colorbar is added).  The
    caller's ``_min``/``_max`` arrays are not modified.
    """
    # ``_min``/``_max`` are the coordinates of the first and last pixels.
    # Widen them by half a pixel spacing on each side so that the imshow
    # extent puts the pixel centres at those coordinates.  Use out-of-place
    # arithmetic so the caller's arrays are left untouched.
    border = 0.5 * (_max - _min) / (np.asarray(img.shape) - 1)
    _min = _min - border
    _max = _max + border
    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        ax = fig.add_subplot(1, 1, 1, aspect='equal')
    else:
        fig = None

    if cmap is None:
        cmap = _p.kwant_red_matplotlib

    # Note that we tell imshow to show the array created by mask_interpolate
    # faithfully and not to interpolate by itself another time.
    image = ax.imshow(img.T, extent=(_min[0], _max[0], _min[1], _max[1]),
                      origin='lower', interpolation='none', cmap=cmap,
                      vmin=vmin, vmax=vmax)
    if num_lead_cells:
        # Overlay only the lead sites (black squares) on top of the map.
        plot(syst, num_lead_cells, site_symbol='no symbol', hop_lw=0,
             lead_site_symbol='s', lead_site_size=0.501, lead_site_lw=0,
             lead_hop_lw=0, lead_color='black', colorbar=False, ax=ax)

    # The axes background is what is visible in areas without sites.
    ax.patch.set_facecolor(background)

    if colorbar and fig is not None:
        # Make the colorbar ends pointy if we saturate the colormap
        extend = 'neither'
        if underflow_pct > 0 and overflow_pct > 0:
            extend = 'both'
        elif underflow_pct > 0:
            extend = 'min'
        elif overflow_pct > 0:
            extend = 'max'
        fig.colorbar(image, extend=extend)

    return fig

1861 

1862 

@deprecate_args
def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
          fig_size=None, ax=None, *, params=None):
    """Plot the band structure of a translationally invariant 1D system.

    Parameters
    ----------
    sys : kwant.system.InfiniteSystem
        The system whose bands are to be plotted.
    args : tuple, defaults to empty
        Positional arguments to pass to the ``hamiltonian`` method.
        Deprecated in favor of 'params' (and mutually exclusive with it).
    momenta : int or 1D array-like
        Either the number of sampling points on the interval [-pi, pi], or an
        array of points at which the band structure is to be evaluated.
    file : string or file object or `None`
        The output file.  If `None`, output will be shown instead.  If plotly
        is selected as the engine, the filename has to end with a html
        extension.
    show : bool
        For the matplotlib engine, whether ``matplotlib.pyplot.show()`` is to
        be called, and the output is to be shown immediately.
        For the plotly engine, a call to ``iplot(fig)`` is made if
        show is True.
        Defaults to `True` for both engines.
    dpi : float
        Number of pixels per inch.  If not set the ``matplotlib`` default is
        used.
        Only for the matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    fig_size : tuple
        Figure size `(width, height)` in inches.  If not set, the default
        ``matplotlib`` value is used.
        Only for the matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`.  In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
        Only for the matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    params : dict, optional
        Dictionary of parameter names and their values.  Mutually exclusive
        with 'args'.

    Returns
    -------
    fig : matplotlib figure or plotly Figure object
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    ValueError
        If the cell Hamiltonian is not Hermitian.

    Notes
    -----
    See `~kwant.physics.Bands` for the calculation of dispersion without
    plotting.
    """
    syst = sys  # for naming consistency inside function bodies

    # Reject matplotlib-only arguments early when plotly is in use.
    if _p.plotly_available and _p.engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)

    _common.ensure_isinstance(
        syst, (system.InfiniteSystem, system.InfiniteVectorizedSystem))

    momenta = np.array(momenta)
    if momenta.ndim != 1:
        # A scalar is interpreted as a number of sampling points.
        momenta = np.linspace(-np.pi, np.pi, momenta)

    # Assemble H(k) by hand (the same content as 'physics.Bands'), since
    # 'spectrum' performs the diagonalisation itself.
    cell_ham = syst.cell_hamiltonian(args, params=params)
    if not np.allclose(cell_ham, cell_ham.conjugate().transpose()):
        raise ValueError('The cell Hamiltonian is not Hermitian.')
    raw_hop = syst.inter_cell_hopping(args, params=params)
    # Pad the inter-cell hopping with zero columns up to the shape of the
    # cell Hamiltonian.
    padded_hop = np.zeros(cell_ham.shape, dtype=complex)
    padded_hop[:, :raw_hop.shape[1]] = raw_hop

    def h_k(k):
        # H(k) = H_0 + V e^{-ik} + V^dagger e^{ik}
        forward = padded_hop * cmath.exp(-1j * k)
        return forward + (forward.conjugate().transpose() + cell_ham)

    return spectrum(h_k, ('k', momenta), file=file, show=show, dpi=dpi,
                    fig_size=fig_size, ax=ax)

1948 

1949 

def spectrum(syst, x, y=None, params=None, mask=None, file=None,
             show=True, dpi=None, fig_size=None, ax=None):
    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters.

    Either matplotlib or plotly must be installed; matplotlib is the default
    engine.

    Parameters
    ----------
    syst : `kwant.system.FiniteSystem` or callable
        If a function, it must take named parameters and return the
        Hamiltonian as a dense matrix.
    x : pair ``(name, values)``
        Parameter to ``ham`` that will be varied: the parameter name together
        with a sequence of parameter values.
    y : pair ``(name, values)``, optional
        Used for 3D plots (same form as ``x``).  If provided, the cartesian
        product of the ``x`` values and these values is used as the grid
        over which the spectrum is evaluated.
    params : dict, optional
        The remaining parameters to ``ham``, held constant.
    mask : callable, optional
        Receives the parameters specified by ``x`` and ``y`` and returns True
        if the spectrum should not be calculated for those parameter values.
    file : string or file object or `None`
        The output file.  If `None`, output will be shown instead.  If plotly
        is selected as the engine, the filename has to end with a html
        extension.
    show : bool
        For the matplotlib engine, whether ``matplotlib.pyplot.show()`` is to
        be called, and the output is to be shown immediately.
        For the plotly engine, a call to ``iplot(fig)`` is made if
        show is True.
        Defaults to `True` for both engines.
    dpi : float
        Number of pixels per inch.  If not set the ``matplotlib`` default is
        used.
        Only for the matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    fig_size : tuple
        Figure size `(width, height)` in inches.  If not set, the default
        ``matplotlib`` value is used.
        Only for the matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`.  In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
        Only for the matplotlib engine.  If the plotly engine is selected and
        this argument is not None, then a RuntimeError will be triggered.

    Returns
    -------
    fig : matplotlib figure or plotly Figure object

    Raises
    ------
    RuntimeError
        If no plotting engine is available or the engine is unsupported.
    """
    params = params or dict()
    engine = _p.engine

    # Dispatch on the active engine; each branch is mutually exclusive.
    if engine == "matplotlib":
        return _spectrum_matplotlib(syst, x, y, params, mask, file,
                                    show, dpi, fig_size, ax)
    if engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)
        return _spectrum_plotly(syst, x, y, params, mask, file, show)
    if engine is None:
        raise RuntimeError("Cannot use spectrum() without a plotting lib installed")
    raise RuntimeError("spectrum() does not support engine '{}'".format(engine))

2018 

2019 

def _generate_spectrum(syst, params, mask, x, y):
    """Generates the spectrum dataset for the internal plotting
    functions of spectrum().

    Parameters
    ----------
    See spectrum(...) documentation.

    Returns
    -------
    spectrum : Numpy array
        The energies of the system calculated at each coordinate, with shape
        ``(len(x values), [len(y values),] n_eigenvalues)``.  Grid points
        rejected by ``mask`` are filled with NaN.
    planar : bool
        True if y is None
    array_values : tuple
        The coordinates of x, y values of the dataset for plotting.
    keys : tuple
        Labels for the x and y axes.

    Raises
    ------
    TypeError
        If ``syst`` is neither a finite Kwant system nor a callable.
    ValueError
        If no grid point is left to evaluate (the grid is empty or every
        point is masked), so the number of eigenvalues cannot be determined.
    """

    if system.is_finite(syst):
        def ham(**kwargs):
            return syst.hamiltonian_submatrix(params=kwargs, sparse=False)
    elif callable(syst):
        ham = syst
    else:
        raise TypeError("Expected 'syst' to be a finite Kwant system "
                        "or a function.")

    planar = y is None
    keys = (x[0],) if planar else (x[0], y[0])
    array_values = (x[1],) if planar else (x[1], y[1])

    # calculate spectrum on the grid of points
    spectrum = []
    bound_ham = functools.partial(ham, **params)
    for point in itertools.product(*array_values):
        p = dict(zip(keys, point))
        if mask and mask(**p):
            spectrum.append(None)
        else:
            h_p = np.atleast_2d(bound_ham(**p))
            spectrum.append(np.linalg.eigvalsh(h_p))

    # Masked grid points are stored as None; replace them by NaN rows with
    # the same length as the computed eigenvalue arrays.  Without any
    # computed point there is nothing to take that length from, so fail
    # with a clear error instead of leaking a bare StopIteration.
    computed = next((s for s in spectrum if s is not None), None)
    if computed is None:
        raise ValueError("Cannot compute the spectrum: no parameter values "
                         "were left to evaluate (the grid is empty or every "
                         "point is masked).")
    nan_list = [np.nan] * len(computed)
    spectrum = [nan_list if s is None else s for s in spectrum]
    # make into a numpy array and reshape
    new_shape = [len(v) for v in array_values] + [-1]
    spectrum = np.array(spectrum).reshape(new_shape)

    return spectrum, planar, array_values, keys

2072 

2073 

def _spectrum_plotly(syst, x, y=None, params=None, mask=None,
                     file=None, show=True):
    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
    using the plotly engine.

    Parameters
    ----------
    See spectrum(...) documentation.  ``params`` may be None, meaning that
    no additional constant parameters are passed.

    Returns
    -------
    fig : plotly Figure / dict
    """
    # Accept the default ``params=None``; _generate_spectrum expands it
    # with ``**`` and the title iterates over it.
    params = params or {}

    spectrum, planar, array_values, keys = _generate_spectrum(syst, params,
                                                              mask, x, y)

    if planar:
        # One trace per eigenvalue index (the last axis of ``spectrum``).
        fig = _p.plotly_graph_objs.Figure(data=[
            _p.plotly_graph_objs.Scatter(
                x=array_values[0],
                y=energies,
            ) for energies in spectrum.T
        ])
        fig.layout.xaxis.title = keys[0]
        fig.layout.yaxis.title = 'Energy'
        fig.layout.showlegend = False
    else:
        # One surface per eigenvalue index; the transpose puts y along rows.
        fig = _p.plotly_graph_objs.Figure(data=[
            _p.plotly_graph_objs.Surface(
                x=array_values[0],
                y=array_values[1],
                z=energies,
                cmax=np.max(spectrum),
                cmin=np.min(spectrum),
            ) for energies in spectrum.T
        ])
        fig.layout.scene.xaxis.title = keys[0]
        fig.layout.scene.yaxis.title = keys[1]
        fig.layout.scene.zaxis.title = 'Energy'

    # Title lists the constant parameters.  Callable values are skipped,
    # consistent with the matplotlib backend.
    fig.layout.title = ', '.join(
        '{} = {}'.format(key, value)
        for key, value in params.items()
        if not callable(value)
    )

    _maybe_output_fig(fig, file=file, show=show)

    return fig

2122 

2123 

def _spectrum_matplotlib(syst, x, y=None, params=None, mask=None, file=None,
                         show=True, dpi=None, fig_size=None, ax=None):
    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
    using the matplotlib engine.

    Parameters
    ----------
    See spectrum(...) documentation.  ``params`` may be None, meaning that
    no additional constant parameters are passed.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    RuntimeError
        If a 2D parameter domain is requested but matplotlib lacks 3D support.
    TypeError
        If a user-supplied ``ax`` cannot draw surfaces for a 2D domain.
    """

    if y is not None and not _p.has3d:
        raise RuntimeError("Installed matplotlib does not support 3d plotting")

    # Accept the default ``params=None``; it is expanded with ``**`` and
    # iterated over for the title.
    params = params or {}

    spectrum, planar, array_values, keys = _generate_spectrum(syst, params,
                                                              mask, x, y)

    # set up axes
    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        if planar:
            ax = fig.add_subplot(1, 1, 1)
        else:
            # Scope the filter to this call.  ``warnings.resetwarnings()``
            # would discard every filter installed by the user or a test
            # runner, not just the one added here.
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message=r'.*mouse rotation disabled.*')
                ax = fig.add_subplot(1, 1, 1, projection='3d')
        ax.set_xlabel(keys[0])
        if planar:
            ax.set_ylabel('Energy')
        else:
            ax.set_ylabel(keys[1])
            ax.set_zlabel('Energy')
        ax.set_title(
            ', '.join(
                '{} = {}'.format(key, value)
                for key, value in params.items()
                if not callable(value)
            )
        )
    else:
        fig = None

    # actually do the plot
    if planar:
        ax.plot(array_values[0], spectrum)
    else:
        if not hasattr(ax, 'plot_surface'):
            msg = ("When providing an axis for plotting over a 2D domain the "
                   "axis should be created with 'projection=\"3d\"")
            raise TypeError(msg)
        # plot_surface cannot directly handle rank-3 values, so we
        # explicitly loop over the last axis
        grid = np.meshgrid(*array_values)
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message='Z contains NaN values')
            for i in range(spectrum.shape[-1]):
                spec = spectrum[:, :, i].transpose()  # row-major to x-y ordering
                ax.plot_surface(*(grid + [spec]), cstride=1, rstride=1)

    _maybe_output_fig(fig, file=file, show=show)

    return fig

2191 

2192 

# Smoothing functions used by 'interpolate_current' and 'interpolate_density'.

2194 

2195# Convolution kernel with finite support: 

2196# f(r) = (1-r^2)^2 Θ(1-r^2) 

2197def _bump(r): 

2198 r[r > 1] = 1 

2199 m = 1 - r * r 

2200 return m * m 

2201 

2202 

2203# We generate the smoothing function by convolving the current 

2204# defined on a line between the two sites with 

2205# f(ρ, z) = (1 - ρ^2 - z^2)^2 Θ(1 - ρ^2 - z^2), where ρ and z are 

2206# cylindrical coords defined with respect to the hopping. 

2207# 'F' is the result of the convolution. 

2208def _smoothing(rho, z): 

2209 r = 1 - rho * rho 

2210 r[r < 0] = 0 

2211 r = np.sqrt(r) 

2212 m = np.clip(z, -r, r) 

2213 rr = r * r 

2214 rrrr = rr * rr 

2215 mm = m * m 

2216 return m * (mm * (mm/5 - (2/3) * rr) + rrrr) + (8 / 15) * rrrr * r 

2217 

2218 

# We need to normalize the smoothing function so that it has unit cross
# section in the plane perpendicular to the hopping. This is equivalent
# to normalizing the integral of 'f' over the unit ball to 1.
# The smoothing function goes as F(ρ) = (16/15) (1 - ρ^2)^(5/2) in the
# plane perpendicular to the hopping, so in n dimensions the cross section is:
# A_n = (16 / 15) * σ_n * ∫_0^1 ρ^(n-2) (1 - ρ^2)^(5/2) dρ
# where σ_n is the surface element prefactor of the (n-1)-dimensional
# perpendicular space (2 in 2D, 2π in 3D). In 1D there is no perpendicular
# direction and A_1 = F(0) = 16/15. Rather than calculate A_n every time, we
# hard code its value for 1, 2 and 3D (indexed by dimension - 1).
_smoothing_cross_sections = [16 / 15, np.pi / 3, 32 * np.pi / 105]

2228 

2229 

2230# Determine the optimal bump function width from the absolute and 

2231# relative widths provided, and the lengths of all the hoppings in the system 

2232def _optimal_width(lens, abswidth, relwidth, bbox_size): 

2233 if abswidth is None: 

2234 if relwidth is None: 

2235 unique_lens = np.unique(lens) 

2236 longest = unique_lens[-1] 

2237 for shortest_nonzero in unique_lens: 2237 ↛ 2240line 2237 didn't jump to line 2240, because the loop on line 2237 didn't complete

2238 if shortest_nonzero / longest > 1e-3: 2238 ↛ 2237line 2238 didn't jump to line 2237, because the condition on line 2238 was never false

2239 break 

2240 width = 4 * shortest_nonzero 

2241 else: 

2242 width = relwidth * np.max(bbox_size) 

2243 else: 

2244 width = abswidth 

2245 

2246 return width 

2247 

2248 

2249# Create empty field array that covers the bounding box plus 

2250# some additional padding 

2251def _create_field(dim, bbox_size, width, n, is_current): 

2252 field_shape = np.zeros(dim + 1, int) 

2253 field_shape[dim] = dim if is_current else 1 

2254 for d in range(dim): 

2255 field_shape[d] = int(bbox_size[d] * n / width + n) 

2256 if field_shape[d] % 2: 

2257 field_shape[d] += 1 

2258 field = np.zeros(field_shape) 

2259 # padding is width / 2 

2260 return field, width / 2 

2261 

2262 

def density_kernel(coords):
    """Bump kernel around a single site, with a trailing length-1 axis.

    ``coords`` holds the per-axis grid coordinates relative to the site, in
    units of the bump radius (as built by ``_interpolate_field``); summing
    ``coords * coords`` yields the squared distance of each grid point.
    """
    squared_distance = np.sum(coords * coords)
    values = _bump(np.sqrt(squared_distance))
    return values[..., np.newaxis]

2266 

2267 

def current_kernel(coords, direction, length):
    """Smoothed current density of a unit current along a line segment.

    ``coords`` holds the per-axis grid coordinates relative to the segment's
    start, in units of the bump radius; ``direction`` is the unit vector
    along the segment and ``length`` its length in the same units. Returns
    vectors parallel to ``direction`` (trailing axis of size ``dim``).
    """
    along = np.dot(coords, direction)
    # Perpendicular distance from the segment's axis; abs() guards against
    # tiny negative values from rounding.
    across = np.sqrt(np.abs(np.sum(coords * coords) - along * along))
    magnitude = _smoothing(across, along) - _smoothing(across, along - length)
    return direction * magnitude[..., np.newaxis]

2273 

2274 

# interpolate a discrete scalar or vector field.
def _interpolate_field(dim, elements, discrete_field, bbox, width,
                       padding, field_out):
    """Smooth a discrete field with a bump kernel and add it to ``field_out``.

    Parameters
    ----------
    dim : int
        Spatial dimension.
    elements : array
        Site positions, shape ``(nsites, dim)``, for a density; or hopping
        end points, shape ``(nhops, 2, dim)``, for a current. For currents
        the direction is taken from ``elements[:, 0]`` towards
        ``elements[:, 1]``.
    discrete_field : 1D array of float
        Value attached to each element.
    bbox : pair of arrays
        ``(bbox_min, bbox_max)`` of the elements.
    width : float
        Width (diameter) of the smoothing bump.
    padding : float
        Extra space on every side of ``bbox`` covered by the grid.
    field_out : array
        Grid to accumulate into; modified in place.
    """
    field_shape = np.array(field_out.shape)
    bbox_min, bbox_max = bbox

    # Kernels are evaluated in coordinates where the bump has unit radius.
    scale = 2 / width

    # if density elements is shape (nsites, dim)
    # if current elements is shape (nhops, 2, dim)
    assert elements.shape[-1] == dim
    is_current = len(elements.shape) == 3
    if is_current:
        assert elements.shape[1] == 2
        dirs = elements[:, 1] - elements[:, 0]
        lens = np.sqrt(np.sum(dirs * dirs, axis=-1))
        # Normalize the directions. Zero-length hoppings (e.g. between sites
        # at the same position) get a zero direction instead of NaN, which
        # makes their kernel vanish: a segment of zero length carries no
        # current density.
        dirs = np.divide(dirs, lens[:, None], out=np.zeros_like(dirs),
                         where=lens[:, None] > 0)
        lens = lens * scale

    if is_current:
        pos_offsets = elements[:, 0]  # first site in hopping
        kernel = current_kernel
    else:
        pos_offsets = elements  # sites themselves
        kernel = density_kernel

    # Grid coordinates along each axis.
    region = [np.linspace(bbox_min[d] - padding,
                          bbox_max[d] + padding,
                          field_shape[d])
              for d in range(dim)]

    # Grid points per unit length along each axis.
    grid_density = (field_shape[:dim] - 1) / (bbox_max + 2*padding - bbox_min)

    # slices for indexing 'field' and 'region' array: for each element, the
    # range of grid indices within 'padding' of it (callers pass
    # padding = width / 2, i.e. the bump radius).
    slices = np.empty((len(discrete_field), dim, 2), int)
    if is_current:
        mn = np.min(elements, 1)
        mx = np.max(elements, 1)
    else:
        mn = mx = elements
    slices[:, :, 0] = np.floor((mn - bbox_min) * grid_density)
    slices[:, :, 1] = np.ceil((mx + 2*padding - bbox_min) * grid_density)

    for i in range(len(discrete_field)):

        if not np.diff(slices[i]).all() or not discrete_field[i]:
            # Zero volume or zero field: nothing to do.
            continue

        field_slice = tuple([slice(*slices[i, d]) for d in range(dim)])

        # Coordinates of the grid points that are within range of the current
        # hopping.
        coords = np.array(
            np.meshgrid(
                *[region[d][field_slice[d]] for d in range(dim)],
                sparse=True, indexing='ij'
            ),
            dtype=object
        )

        # Convert "coords" into scaled distances from pos_offset
        coords -= pos_offsets[i]
        coords *= scale
        magns = kernel(coords, dirs[i], lens[i]) if is_current else kernel(coords)
        magns *= discrete_field[i]

        field_out[field_slice] += magns

    # NOTE(review): a kernel with unit integral would need the factor
    # scale**(dim - 1) for currents and scale**dim for densities; the single
    # factor 'scale' used here matches that only for 2D currents. Confirm
    # whether the absolute normalization in the other cases is intended.
    field_out *= scale / _smoothing_cross_sections[dim - 1]

2346 

2347 

def interpolate_current(syst, current, relwidth=None, abswidth=None, n=9):
    """Interpolate currents in a system onto a regular grid.

    The system graph together with current intensities defines a "discrete"
    current density field where the current density is non-zero only on the
    straight lines that connect sites that are coupled by a hopping term.

    To make this vector field easier to visualize and interpret at different
    length scales, it is smoothed by convoluting it with the bell-shaped bump
    function ``f(r) = max(1 - (2*r / width)**2, 0)**2``. The bump width is
    determined by the `relwidth` and `abswidth` parameters.

    This routine samples the smoothed field on a regular (square or cubic)
    grid.

    Parameters
    ----------
    syst : A finalized system
        The system on which we are going to calculate the field.
    current : 1D array of float
        Must contain the intensity on each hopping in the same order that they
        appear in syst.graph. The values given for the two directions of a
        hopping are combined: the current along each hopping is the average
        of the forward value and the negated backward value.
    relwidth : float or `None`
        Relative width of the bumps used to generate the field, as a fraction
        of the length of the longest side of the bounding box. This argument
        is only used if `abswidth` is not given.
    abswidth : float or `None`
        Absolute width of the bumps used to generate the field. Takes
        precedence over `relwidth`. If neither is given, the bump width is set
        to four times the length of the shortest hopping.
    n : int
        Number of points the grid must have over the width of the bump.

    Returns
    -------
    field : n-d arraylike of float
        n-d array of n-d vectors (the last axis holds the vector components).
    box : sequence of 2-sequences of float
        the extents of `field`: ((x0, x1), (y0, y1), ...)

    Raises
    ------
    TypeError
        If `syst` is not a finalized system.
    ValueError
        If `current` does not match the number of hoppings, or the system
        has no hoppings.
    """
    if not builder.is_finite_system(syst):
        raise TypeError("The system needs to be finalized.")

    if len(current) != syst.graph.num_edges:
        raise ValueError("Current and hoppings arrays do not have the same"
                         " length.")

    if not syst.graph.num_edges:
        raise ValueError("The system has no hoppings, so there is no current"
                         " to interpolate.")

    # hops: hoppings (pairs of points)
    dim = len(syst.sites[0].pos)
    hops = np.empty((syst.graph.num_edges // 2, 2, dim))
    # Take the average of the current flowing each way along the hoppings
    current_one_way = np.empty(syst.graph.num_edges // 2)
    seen_hoppings = dict()
    kprime = 0
    for k, (i, j) in enumerate(syst.graph):
        if (j, i) in seen_hoppings:
            current_one_way[seen_hoppings[j, i]] -= current[k]
        else:
            current_one_way[kprime] = current[k]
            hops[kprime][0] = syst.sites[j].pos
            hops[kprime][1] = syst.sites[i].pos
            seen_hoppings[i, j] = kprime
            kprime += 1
    current = current_one_way / 2

    min_hops = np.min(hops, 1)
    max_hops = np.max(hops, 1)
    bbox_min = np.min(min_hops, 0)
    bbox_max = np.max(max_hops, 0)
    bbox_size = bbox_max - bbox_min

    # Unscaled lengths of the hoppings, used to choose a default bump width.
    # (Directions are normalized inside '_interpolate_field'; normalizing
    # here as well would only emit warnings for zero-length hoppings.)
    hop_vectors = hops[:, 1] - hops[:, 0]
    lens = np.sqrt(np.sum(hop_vectors * hop_vectors, -1))
    width = _optimal_width(lens, abswidth, relwidth, bbox_size)

    field, padding = _create_field(dim, bbox_size, width, n, is_current=True)
    boundaries = tuple((bbox_min[d] - padding, bbox_max[d] + padding)
                       for d in range(dim))
    _interpolate_field(dim, hops, current,
                       (bbox_min, bbox_max), width, padding, field)

    return field, boundaries

2435 

2436 

def interpolate_density(syst, density, relwidth=None, abswidth=None, n=9,
                        mask=True):
    """Interpolate density in a system onto a regular grid.

    The system sites together with a scalar for each site defines a "discrete"
    density field where the density is non-zero only at the site positions.

    To make this scalar field easier to visualize and interpret at different
    length scales, it is smoothed by convoluting it with the bell-shaped bump
    function ``f(r) = max(1 - (2*r / width)**2, 0)**2``. The bump width is
    determined by the `relwidth` and `abswidth` parameters.

    This routine samples the smoothed field on a regular (square or cubic)
    grid.

    Parameters
    ----------
    syst : A finalized system
        The system on which we are going to calculate the field.
    density : 1D array of float
        Must contain the intensity on each site in the same order that they
        appear in syst.sites.
    relwidth : float, optional
        Relative width of the bumps used to smooth the field, as a fraction
        of the length of the longest side of the bounding box. This argument
        is only used if ``abswidth`` is not given.
    abswidth : float, optional
        Absolute width of the bumps used to smooth the field. Takes
        precedence over ``relwidth``. If neither is given, the bump width is set
        to four times the length of the shortest hopping.
    n : int
        Number of points the grid must have over the width of the bump.
    mask : bool
        If True, this function returns a masked array that masks positions that
        are too far away from any sites. This is useful for showing an approximate
        outline of the system when the field is plotted.

    Returns
    -------
    field : n-d arraylike of float
        n-d grid of interpolated density values; the last axis has length 1.
        A masked array if ``mask`` is True.
    box : sequence of 2-sequences of float
        the extents of ``field``: ((x0, x1), (y0, y1), ...)

    Raises
    ------
    TypeError
        If ``syst`` is not a finalized system.
    ValueError
        If ``density`` does not have one entry per site.
    """
    if not builder.is_finite_system(syst):
        raise TypeError("The system needs to be finalized.")

    if len(density) != len(syst.sites):
        raise ValueError("Density and sites arrays do not have the same"
                         " length.")

    dim = len(syst.sites[0].pos)
    sites = np.array([s.pos for s in syst.sites])

    bbox_min = np.min(sites, axis=0)
    bbox_max = np.max(sites, axis=0)
    bbox_size = bbox_max - bbox_min

    # Determine the optimal width for the bump function
    dirs = np.array([syst.sites[i].pos - syst.sites[j].pos
                     for i, j in syst.graph])
    lens = np.sqrt(np.sum(dirs * dirs, -1))
    width = _optimal_width(lens, abswidth, relwidth, bbox_size)

    field, padding = _create_field(dim, bbox_size, width, n, is_current=False)
    boundaries = tuple((bbox_min[d] - padding, bbox_max[d] + padding)
                       for d in range(dim))
    _interpolate_field(dim, sites, density,
                       (bbox_min, bbox_max), width, padding, field)

    if mask:
        # Field is zero when we are > 0.5*width from any site (as bump has
        # finite support), so we mask positions a little further than this.
        field = _mask(field,
                      box=boundaries,
                      coords=sites,
                      cutoff=0.6*width)

    return field, boundaries

2517 

2518 

2519def _gamma_compress(linear): 

2520 """Compress linear sRGB into sRGB.""" 

2521 if linear <= 0.0031308: 

2522 return 12.92 * linear 

2523 else: 

2524 a = 0.055 

2525 return (1 + a) * linear ** (1 / 2.4) - a 

2526 

2527_gamma_compress = np.vectorize(_gamma_compress, otypes=[float]) 

2528 

2529 

2530def _gamma_expand(corrected): 

2531 """Expand sRGB into linear sRGB.""" 

2532 if corrected <= 0.04045: 

2533 return corrected / 12.92 

2534 else: 

2535 a = 0.055 

2536 return ((corrected + a) / (1 + a))**2.4 

2537 

2538_gamma_expand = np.vectorize(_gamma_expand, otypes=[float]) 

2539 

2540 

def _linear_cmap(a, b):
    """Make a colormap that linearly interpolates between the colors a and b.

    The 256 entries run from ``b`` (value 0) to ``a`` (value 1). The
    interpolation is done on linear-light RGB values, which are
    gamma-encoded again afterwards.
    """
    to_rgb = _p.matplotlib.colors.colorConverter.to_rgb
    end_linear = _gamma_expand(to_rgb(a))
    start_linear = _gamma_expand(to_rgb(b))
    weights = np.linspace(0, 1, 256)[:, np.newaxis]
    palette = weights * (end_linear - start_linear)[np.newaxis, :]
    palette += start_linear
    return _p.matplotlib.colors.ListedColormap(_gamma_compress(palette))

2553 

2554 

def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
               max_linewidth=3, min_linewidth=1, density=2/9,
               colorbar=True, file=None,
               show=True, dpi=None, fig_size=None, ax=None,
               vmax=None):
    """Draw streamlines of a flow field in Kwant style

    Solid colored streamlines are drawn, superimposed on a color plot of
    the flow speed that may be disabled by setting `bgcolor`. The width
    of the streamlines is proportional to the flow speed. Lines that
    would be thinner than `min_linewidth` are blended in a perceptually
    correct way into the background color in order to create the
    illusion of arbitrarily thin lines. (This is done because some plot
    engines like PDF do not support lines of arbitrarily thin width.)

    With the matplotlib engine this routine uses matplotlib's streamplot.
    The plotly engine is not supported yet and raises `RuntimeError`.

    Parameters
    ----------
    field : 3d arraylike of float
        2d array of 2d vectors.
    box : 2-sequence of 2-sequences of float
        the extents of `field`: ((x0, x1), (y0, y1))
    cmap : colormap, optional
        Colormap for the background color plot. When not set the colormap
        "kwant_red" is used by default, unless `bgcolor` is set.
    bgcolor : color definition, optional
        The solid color of the background. Mutually exclusive with `cmap`.
    linecolor : color definition
        Color of the flow lines.
    max_linewidth : float
        Width of lines at maximum flow speed.
    min_linewidth : float
        Minimum width of lines before blending into the background color begins.
    density : float
        Number of flow lines per point of the field. The default value
        of 2/9 is chosen to show two lines per default width of the
        interpolation bump of `~kwant.plotter.interpolate_current`.
    colorbar : bool
        Whether to show a colorbar if a colormap is used. Ignored if `ax` is
        provided.
    file : string or file object or `None`
        The output file. If `None`, output will be shown instead.
    show : bool
        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
        to be shown immediately. Defaults to `True`.
    dpi : float or `None`
        Number of pixels per inch. If not set the ``matplotlib`` default is
        used.
    fig_size : tuple or `None`
        Figure size `(width, height)` in inches. If not set, the default
        ``matplotlib`` value is used.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`. In this case, `file`, `show`, `dpi`
        and `fig_size` are ignored.
    vmax : float or `None`
        The upper saturation limit for the colormap; flows higher than
        this will saturate. Note that there is no corresponding vmin
        option, vmin being fixed at zero.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    RuntimeError
        If no plotting engine is available, or the engine is unsupported.
    """
    if _p.engine == "matplotlib":
        fig = _streamplot_matplotlib(field, box, cmap, bgcolor, linecolor,
                                     max_linewidth, min_linewidth, density,
                                     colorbar, file, show, dpi, fig_size, ax,
                                     vmax)
    elif _p.engine == "plotly":
        _check_incompatible_args_plotly(dpi, fig_size, ax)
        fig = _streamplot_plotly(field, box, cmap, bgcolor, linecolor,
                                 max_linewidth, min_linewidth, density,
                                 colorbar, file, show, vmax)
    elif _p.engine is None:
        raise RuntimeError("Cannot use streamplot() without a plotting lib installed")
    else:
        raise RuntimeError("streamplot() does not support engine '{}'".format(_p.engine))
    _maybe_output_fig(fig, file=file, show=show)

    # Return the figure like 'scalarplot' does; 'current' relies on this.
    return fig

2574 

2575 

2576def _streamplot_plotly(field, box, cmap, bgcolor, linecolor, 

2577 max_linewidth, min_linewidth, density, 

2578 colorbar, file, show, vmax): 

2579 raise RuntimeError("Streamplot() for plotly engine not implemented yet due to bug from plotly") 

2580 

2581 

def _streamplot_matplotlib(field, box, cmap, bgcolor, linecolor,
                           max_linewidth, min_linewidth, density, colorbar, file,
                           show, dpi, fig_size, ax, vmax):
    """Draw streamlines of a flow field in Kwant style

    Solid colored streamlines are drawn, superimposed on a color plot of
    the flow speed that may be disabled by setting `bgcolor`. The width
    of the streamlines is proportional to the flow speed. Lines that
    would be thinner than `min_linewidth` are blended in a perceptually
    correct way into the background color in order to create the
    illusion of arbitrarily thin lines. (This is done because some plot
    engines like PDF do not support lines of arbitrarily thin width.)

    Internally, this routine uses matplotlib's streamplot. It does not save
    or show the figure; that is left to the caller.

    Parameters
    ----------
    field : 3d arraylike of float
        2d array of 2d vectors.
    box : 2-sequence of 2-sequences of float
        the extents of `field`: ((x0, x1), (y0, y1))
    cmap : colormap, optional
        Colormap for the background color plot. When not set the colormap
        "kwant_red" is used by default, unless `bgcolor` is set.
    bgcolor : color definition, optional
        The solid color of the background. Mutually exclusive with `cmap`.
    linecolor : color definition
        Color of the flow lines.
    max_linewidth : float
        Width of lines at maximum flow speed.
    min_linewidth : float
        Minimum width of lines before blending into the background color begins.
    density : float
        Number of flow lines per point of the field. The default value
        of 2/9 is chosen to show two lines per default width of the
        interpolation bump of `~kwant.plotter.interpolate_current`.
    colorbar : bool
        Whether to show a colorbar if a colormap is used. Ignored if `ax` is
        provided.
    file : string or file object or `None`
        The output file. Only used here to decide whether pyplot is used to
        create the figure.
    show : bool
        Not used here; output is handled by the caller.
    dpi : float or `None`
        Number of pixels per inch. If not set the ``matplotlib`` default is
        used.
    fig_size : tuple or `None`
        Figure size `(width, height)` in inches. If not set, the default
        ``matplotlib`` value is used.
    ax : ``matplotlib.axes.Axes`` instance or `None`
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`. In this case `dpi` and `fig_size` are
        ignored.
    vmax : float or `None`
        The upper saturation limit for the colormap; flows higher than
        this will saturate. Note that there is no corresponding vmin
        option, vmin being fixed at zero.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None.

    Raises
    ------
    ValueError
        If `field` is not a 2d array of 2d vectors, or if both `cmap` and
        `bgcolor` are given.
    """
    # Validate before transposing: 'transpose(1, 0, 2)' itself fails on
    # arrays that are not 3d.
    if field.ndim != 3 or field.shape[-1] != 2:
        raise ValueError("Only 2D field can be plotted.")

    # Matplotlib's "density" is in units of 30 streamlines...
    density *= 1 / 30 * ta.array(field.shape[:2], int)

    # Matplotlib plots images like matrices: image[y, x]. We use the opposite
    # convention: image[x, y]. Hence, it is necessary to transpose.
    field = field.transpose(1, 0, 2)

    if bgcolor is None:
        if cmap is None:
            cmap = _p.kwant_red_matplotlib
        cmap = _p.matplotlib.cm.get_cmap(cmap)
        bgcolor = cmap(0)[:3]
    elif cmap is not None:
        raise ValueError("The parameters 'cmap' and 'bgcolor' are "
                         "mutually exclusive.")

    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        ax = fig.add_subplot(1, 1, 1, aspect='equal')
    else:
        fig = None

    X = np.linspace(*box[0], num=field.shape[1])
    Y = np.linspace(*box[1], num=field.shape[0])

    speed = np.linalg.norm(field, axis=-1)
    if vmax is None:
        vmax = np.max(speed) or 1

    if cmap is None:
        # Solid background. 'Axes.set_axis_bgcolor' was removed from
        # matplotlib, so set the color of the axes patch directly.
        ax.patch.set_facecolor(bgcolor)
    else:
        image = ax.imshow(speed, cmap=cmap,
                          interpolation='bicubic',
                          extent=[e for c in box for e in c],
                          origin='lower', vmin=0, vmax=vmax)

    # Lines thinner than 'min_linewidth' are drawn at that width instead,
    # with a color blended towards the background in proportion to how much
    # thinner they should have been.
    linewidth = max_linewidth / vmax * speed
    color = linewidth / min_linewidth
    thin = linewidth < min_linewidth
    linewidth[thin] = min_linewidth
    color[~ thin] = 1

    line_cmap = _linear_cmap(linecolor, bgcolor)

    ax.streamplot(X, Y, field[:,:,0], field[:,:,1],
                  density=density, linewidth=linewidth,
                  color=color, cmap=line_cmap, arrowstyle='->',
                  norm=_p.matplotlib.colors.Normalize(0, 1))

    ax.set_xlim(*box[0])
    ax.set_ylim(*box[1])

    if colorbar and cmap is not None and fig is not None:
        fig.colorbar(image)

    # Saving/showing is done by the caller; doing it here as well would
    # write the file or show the figure twice.
    return fig

2709 

2710 

2711def scalarplot(field, box, 

2712 cmap=None, colorbar=True, file=None, show=True, 

2713 dpi=None, fig_size=None, ax=None, vmin=None, vmax=None, 

2714 background='#e0e0e0'): 

2715 """Draw a scalar field in Kwant style 

2716 

2717 Internally, this routine uses matplotlib's imshow. 

2718 

2719 Parameters 

2720 ---------- 

2721 field : 2d arraylike of float 

2722 2d scalar field to plot. 

2723 box : pair of pair of float 

2724 the realspace extents of ``field``: ((x0, x1), (y0, y1)) 

2725 cmap : colormap, optional 

2726 Colormap for the background color plot. When not set the colormap 

2727 "kwant_red" is used by default. 

2728 colorbar : bool, default: True 

2729 Whether to show a colorbar if a colormap is used. Ignored if `ax` is 

2730 provided. 

2731 file : string or file object, optional 

2732 The output file. If not provided, output will be shown instead. 

2733 show : bool, default: True 

2734 Whether ``matplotlib.pyplot.show()`` is to be called, and the output is 

2735 to be shown immediately. 

2736 dpi : float, optional 

2737 Number of pixels per inch. If not set the ``matplotlib`` default is 

2738 used. 

2739 fig_size : tuple, optional 

2740 Figure size ``(width, height)`` in inches. If not set, the default 

2741 ``matplotlib`` value is used. 

2742 ax : ``matplotlib.axes.Axes`` instance, optional 

2743 If ``ax`` is provided, no new figure is created, but the plot is done 

2744 within the existing Axes ``ax``. in this case, ``file``, ``show``, 

2745 ``dpi`` and ``fig_size`` are ignored. 

2746 vmin, vmax : float, optional 

2747 The lower/upper saturation limit for the colormap. 

2748 background : matplotlib color spec 

2749 Areas outside the system are filled with this color. 

2750 

2751 Returns 

2752 ------- 

2753 fig : matplotlib figure 

2754 A figure with the output if ``ax`` is not set, else None. 

2755 """ 

2756 

2757 # Matplotlib plots images like matrices: image[y, x]. We use the opposite 

2758 # convention: image[x, y]. Hence, it is necessary to transpose. 

2759 # Also squeeze out the last axis as it is just a scalar field 

2760 

2761 field = field.squeeze(axis=-1).transpose() 

2762 

2763 if field.ndim != 2: 

2764 raise ValueError("Only 2D field can be plotted.") 

2765 

2766 if vmin is None: 

2767 vmin = np.min(field) 

2768 if vmax is None: 

2769 vmax = np.max(field) 

2770 

2771 if _p.engine == "matplotlib": 

2772 fig = _scalarplot_matplotlib(field, box, cmap, colorbar, 

2773 file, show, dpi, fig_size, ax, 

2774 vmin, vmax, background) 

2775 elif _p.engine == "plotly": 

2776 _check_incompatible_args_plotly(dpi, fig_size, ax) 

2777 fig = _scalarplot_plotly(field, box, cmap, colorbar, file, 

2778 show, vmin, vmax, background) 

2779 elif _p.engine is None: 

2780 raise RuntimeError("Cannot use scalarplot() without a plotting lib installed") 

2781 else: 

2782 raise RuntimeError("scalarplot() does not support engine '{}'".format(_p.engine)) 

2783 _maybe_output_fig(fig, file=file, show=show) 

2784 

2785 return fig 

2786 

2787 

def _scalarplot_plotly(field, box, cmap, colorbar, file,
                       show, vmin, vmax, background):
    """Render a 2d scalar field as a plotly heatmap.

    ``field`` must be indexed as ``field[y, x]`` (rows along y, columns
    along x), which is how `scalarplot` passes it. ``file`` and ``show`` are
    not used here; `scalarplot` handles the output.

    Returns
    -------
    fig : plotly Figure
    """
    if cmap is None:
        cmap = _p.kwant_red_plotly

    heatmap = _p.plotly_graph_objs.Heatmap()
    heatmap.z = field
    # Plotly maps the columns of 'z' to 'x' and its rows to 'y', so the
    # x coordinates need field.shape[1] points and the y coordinates
    # field.shape[0] points (the previous code had them swapped, which
    # mislabels non-square fields).
    heatmap.x = np.linspace(*box[0], field.shape[1])
    heatmap.y = np.linspace(*box[1], field.shape[0])
    heatmap.zsmooth = 'best'
    heatmap.colorscale = cmap
    heatmap.zmax = vmax
    heatmap.zmin = vmin

    heatmap.showscale = colorbar

    fig = _p.plotly_graph_objs.Figure(data=[heatmap])
    fig.layout.plot_bgcolor = background

    return fig

2809 

2810 

def _scalarplot_matplotlib(field, box, cmap, colorbar, file, show, dpi,
                           fig_size, ax, vmin, vmax, background):
    """Render a 2d scalar field with matplotlib's imshow.

    ``field`` is indexed as ``field[y, x]``. A new figure is created unless
    ``ax`` is given; the figure (or None) is returned without being saved
    or shown.
    """
    cmap = _p.matplotlib.cm.get_cmap(
        _p.kwant_red_matplotlib if cmap is None else cmap)

    fig = None
    if ax is None:
        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
        ax = fig.add_subplot(1, 1, 1, aspect='equal')

    # imshow wants the extent flattened as [x0, x1, y0, y1].
    extent = [edge for axis_range in box for edge in axis_range]
    image = ax.imshow(field, cmap=cmap,
                      interpolation='bicubic',
                      extent=extent,
                      origin='lower', vmin=vmin, vmax=vmax)

    ax.set_xlim(*box[0])
    ax.set_ylim(*box[1])
    # Visible wherever the (masked) field is not drawn.
    ax.patch.set_facecolor(background)

    if colorbar and cmap and fig is not None:
        fig.colorbar(image)

    return fig

2837 

2838 

def current(syst, current, relwidth=0.05, **kwargs):
    """Show an interpolated current defined for the hoppings of a system.

    The system graph together with current intensities defines a "discrete"
    current density field where the current density is non-zero only on the
    straight lines that connect sites that are coupled by a hopping term.

    To make this vector field easier to visualize and interpret at different
    length scales, it is smoothed by convoluting it with the bell-shaped bump
    function ``f(r) = max(1 - (2*r / width)**2, 0)**2``. The bump width is
    determined by the ``relwidth`` parameter.

    This routine samples the smoothed field on a regular (square or cubic) grid
    and displays it using an enhanced variant of matplotlib's streamplot.

    This is a convenience function that is equivalent to
    ``streamplot(*interpolate_current(syst, current, relwidth), **kwargs)``.
    The longer form makes it possible to tweak additional options of
    `~kwant.plotter.interpolate_current`.

    Parameters
    ----------
    syst : `kwant.system.FiniteSystem`
        The system for which to plot the ``current``.
    current : sequence of float
        Sequence of values defining currents on each hopping of the system.
        Ordered in the same way as ``syst.graph``. This typically will be
        the result of evaluating a `~kwant.operator.Current` operator.
    relwidth : float or `None`
        Relative width of the bumps used to smooth the field, as a fraction
        of the length of the longest side of the bounding box. If `None`, the
        default of `~kwant.plotter.interpolate_current` is used (four times
        the length of the shortest hopping).
    **kwargs : various
        Keyword args to be passed verbatim to `kwant.plotter.streamplot`.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if ``ax`` is not set, else None.

    See Also
    --------
    kwant.plotter.density
    """
    with _common.reraise_warnings(4):
        return streamplot(*interpolate_current(syst, current, relwidth),
                          **kwargs)

2885 

2886 

2887def _mask(field, box, coords, cutoff): 

2888 tree = spatial.cKDTree(coords) 

2889 

2890 # Build the mask initially as a 2D array 

2891 dims = tuple(slice(boxmin, boxmax, 1j * shape) 

2892 for (boxmin, boxmax), shape in zip(box, field.shape)) 

2893 mask = np.mgrid[dims].reshape(len(box), -1).T 

2894 

2895 mask = tree.query(mask, distance_upper_bound=cutoff)[0] == np.inf 

2896 return np.ma.masked_array(field, mask) 

2897 

2898 

def density(syst, density, relwidth=0.05, **kwargs):
    """Show an interpolated density defined on the sites of a system.

    The sites of a system, each carrying one scalar value, define a
    "discrete" density field that is non-zero only at the site positions.

    To make this scalar field easier to visualize and interpret at different
    length scales, it is smoothed by convoluting it with the bell-shaped bump
    function ``f(r) = max(1 - (2*r / width)**2, 0)**2``. The bump width is
    determined by the ``relwidth`` parameter.

    The smoothed field is sampled on a regular (square or cubic) grid and
    displayed with matplotlib's imshow.

    Compared with `~kwant.plotter.map`, this usually looks better for systems
    with many sites; use `~kwant.plotter.map` when site-level resolution
    matters.

    This is shorthand for
    ``scalarplot(*interpolate_density(syst, density, relwidth), **kwargs)``;
    use the long form to reach further options of
    `~kwant.plotter.interpolate_density`.

    Parameters
    ----------
    syst : `kwant.system.FiniteSystem`
        The system for which to plot ``density``.
    density : sequence of float
        One value per site of the system, ordered like ``syst.sites``. This
        is typically the result of evaluating a `~kwant.operator.Density`
        operator.
    relwidth : float or `None`
        Relative width of the bumps used to smooth the field, as a fraction
        of the length of the longest side of the bounding box.
    **kwargs : various
        Keyword args forwarded unchanged to `~kwant.plotter.scalarplot`.

    Returns
    -------
    fig : matplotlib figure
        A figure with the output if ``ax`` is not set, else None.

    See Also
    --------
    kwant.plotter.current
    kwant.plotter.map
    """
    with _common.reraise_warnings(4):
        interpolated = interpolate_density(syst, density, relwidth)
        return scalarplot(*interpolated, **kwargs)

2949 

2950 

2951# TODO (Anton): Fix plotting of parts of the system using color = np.nan. 

2952# Not plotting sites currently works, not plotting hoppings does not. 

2953# TODO (Anton): Allow a more flexible treatment of position than pos_transform 

2954# (an interface for user-defined pos).