Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2# Copyright 2011-2019 Kwant authors. 

3# 

4# This file is part of Kwant. It is subject to the license terms in the file 

5# LICENSE.rst found in the top-level directory of this distribution and at 

6# https://kwant-project.org/license. A list of Kwant authors can be found in 

7# the file AUTHORS.rst at the top-level directory of this distribution and at 

8# https://kwant-project.org/authors. 

9 

10# This module is imported by plotter.py. It contains all the expensive imports 

11# that we want to remove from plotter.py 

12 

13# All matplotlib imports must be isolated in a try, because even without 

14# matplotlib iterators remain useful. Further, mpl_toolkits used for 3D 

15# plotting are also imported separately, to ensure that 2D plotting works even 

16# if 3D does not. 

17 

18import warnings 

19from math import sqrt, pi 

20import numpy as np 

21from enum import Enum 

22 

23 

# Detect IPython by probing for the ``__IPYTHON__`` name: evaluating the bare
# name raises NameError when it is not defined.
try:
    __IPYTHON__
except NameError:
    is_ipython_kernel = False
else:
    is_ipython_kernel = True

# Availability flags of the plotting engines.  They start out False and are
# switched to True by the guarded imports that follow.  (The former
# module-level ``global`` statements were no-ops: names assigned at module
# scope are already global.)
mpl_available = False
plotly_available = False

34 

# Guarded matplotlib import.  On success ``mpl_available`` is set, the Kwant
# colormap is registered as a ListedColormap and the optional 3D toolkit is
# probed separately, so that 2D plotting keeps working when 3D does not.
try:
    import matplotlib
    import matplotlib.cm
    import matplotlib.colors
    from matplotlib import collections
    from matplotlib.colors import ListedColormap
    from matplotlib.figure import Figure
    from . import _colormaps

    mpl_available = True
    kwant_red_matplotlib = ListedColormap(_colormaps.kwant_red,
                                          name="kwant red")

    # mpl_toolkits is optional: only 3D plotting depends on it.
    try:
        from mpl_toolkits import mplot3d
    except ImportError:
        warnings.warn("3D plotting not available.", RuntimeWarning)
        has3d = False
    else:
        has3d = True
except ImportError:
    warnings.warn("matplotlib is not available, if other engines are "
                  "unavailable, only iterator-providing functions will work",
                  RuntimeWarning)

56 

57 

# Guarded plotly import.  On success ``plotly_available`` is set and the Kwant
# colormap is converted to a plotly colorscale: a list of
# ``(level, 'rgb(r,g,b)')`` pairs with levels evenly spaced in [0, 1].
try:
    import plotly.offline as plotly_module
    import plotly.graph_objs as plotly_graph_objs
    init_notebook_mode_set = False
    from . import _colormaps
    plotly_available = True

    # Scale the colormap entries by 255 for use in 'rgb(...)' strings.
    _cmap_plotly = 255 * _colormaps.kwant_red
    _cmap_levels = np.linspace(0, 1, len(_cmap_plotly))
    kwant_red_plotly = [
        (position, 'rgb({},{},{})'.format(*channels))
        for position, channels in zip(_cmap_levels, _cmap_plotly)
    ]
except ImportError:
    warnings.warn("plotly is not available, if other engines are unavailable,"
                  " only iterator-providing functions will work",
                  RuntimeWarning)

73 

# Available engines, listed in ascending order of preference: the last
# available entry becomes the default ``engine``, so matplotlib is chosen
# over plotly when both can be imported.  ``engine`` is None if neither is
# available.
engines = [
    name
    for name, available in (("plotly", plotly_available),
                            ("matplotlib", mpl_available))
    if available
]
engine = engines[-1] if engines else None
engines = frozenset(engines)

86 

87 

# Collections that allow for symbols and linewidths to be given in data space
# (not for general use, only implement what's needed for plotter)

def isarray(var):
    """Return True if `var` supports indexing and is not a string.

    Any object exposing ``__getitem__`` counts as an array here (lists,
    tuples, dicts, numpy arrays, ...); ``str`` instances are explicitly
    excluded.
    """
    return hasattr(var, '__getitem__') and not isinstance(var, str)

95 

96 

def nparray_if_array(var):
    """Convert `var` with ``np.asarray`` if ``isarray(var)`` holds.

    Anything else (scalars, strings, None, ...) is returned unchanged.
    """
    if isarray(var):
        return np.asarray(var)
    return var

99 

100 

101if plotly_available: 101 ↛ 205line 101 didn't jump to line 205, because the condition on line 101 was never false

102 

# Translation tables from the common matplotlib marker symbols to their
# plotly counterparts.  ``converter_map`` yields plotly's integer symbol
# codes, ``converter_map_3d`` the symbol names used for 3D plots (for which
# fewer markers are listed).  The insertion order is significant: the key
# lists appear verbatim in error messages.
converter_map = {
    "o": 0,
    "v": 6,
    "^": 5,
    "<": 7,
    ">": 8,
    "s": 1,
    "+": 3,
    "x": 4,
    "*": 17,
    "d": 2,
    "h": 14,
    "no symbol": -1,
}

converter_map_3d = {
    "o": "circle",
    "s": "square",
    "+": "cross",
    "x": "x",
    "d": "diamond",
}

127 

def error_string(symbol_input, supported):
    """Build the message reporting an unsupported marker symbol.

    Parameters
    ----------
    symbol_input : object
        The symbol (or sequence of symbols) that was rejected.
    supported : object
        The supported symbols; formatted into the message as-is.
    """
    return (f"Input symbol/s '{symbol_input}' not supported. "
            f"Only the following characters are supported: {supported}")

130 

131 

def convert_symbol_mpl_plotly(mpl_symbol):
    """Translate matplotlib marker symbol(s) to plotly symbol codes.

    Parameters
    ----------
    mpl_symbol : str or sequence of str
        A single matplotlib marker symbol, or an indexable collection of
        them (as decided by ``isarray``).

    Returns
    -------
    int or list of int
        The entry of ``converter_map`` for a single symbol, or a list of
        entries for a collection.

    Raises
    ------
    RuntimeError
        If a symbol is not a key of ``converter_map``.
    """
    # Index the table directly: ``dict.get`` never raises KeyError, so with
    # it the error branch below was unreachable and unsupported symbols were
    # silently mapped to None.
    try:
        if isarray(mpl_symbol):
            return [converter_map[symbol] for symbol in mpl_symbol]
        return converter_map[mpl_symbol]
    except KeyError:
        raise RuntimeError(
            error_string(mpl_symbol, list(converter_map))) from None

144 

145 

def convert_symbol_mpl_plotly_3d(mpl_symbol):
    """Translate matplotlib marker symbol(s) to plotly 3D symbol names.

    Parameters
    ----------
    mpl_symbol : str or sequence of str
        A single matplotlib marker symbol, or an indexable collection of
        them (as decided by ``isarray``).

    Returns
    -------
    str or list of str
        The entry of ``converter_map_3d`` for a single symbol, or a list of
        entries for a collection.

    Raises
    ------
    RuntimeError
        If a symbol is not a key of ``converter_map_3d``.
    """
    # Index the table directly: ``dict.get`` never raises KeyError, so with
    # it the error branch below was unreachable and unsupported symbols were
    # silently mapped to None.
    try:
        if isarray(mpl_symbol):
            return [converter_map_3d[symbol] for symbol in mpl_symbol]
        return converter_map_3d[mpl_symbol]
    except KeyError:
        raise RuntimeError(
            error_string(mpl_symbol, list(converter_map_3d))) from None

158 

159 

def convert_site_size_mpl_plotly(mpl_site_size, plotly_ref_px):
    """Convert a matplotlib marker size to a plotly marker size.

    The matplotlib size is assumed to be given in square points (see
    https://matplotlib.org/devdocs/api/_as_gen/matplotlib.pyplot.scatter.html),
    so its square root is a length in points.  That length is converted to
    pixels with the factor 96/72 (pixels per point) and finally multiplied
    by `plotly_ref_px`.
    """
    pixels_per_point = 96.0 / 72.0
    return np.sqrt(mpl_site_size) * pixels_per_point * plotly_ref_px

166 

167 

def convert_colormap_mpl_plotly(r, g, b, a):
    """Format an RGBA color as a plotly ``'rgba(r,g,b,a)'`` string.

    The color channels `r`, `g` and `b` are multiplied by 255; the alpha
    channel `a` is inserted unchanged.
    """
    return "rgba({},{},{},{})".format(255*r, 255*g, 255*b, a)

170 

171 

def convert_cmap_list_mpl_plotly(mpl_cmap_name):
    """Convert a matplotlib colormap name to a plotly colorscale list.

    Parameters
    ----------
    mpl_cmap_name : str or list
        Name of a matplotlib colormap, or an already converted colorscale
        list, which is returned unchanged.

    Returns
    -------
    list
        For a name: ``(level, 'rgba(...)')`` pairs sampled at ``cmap.N``
        evenly spaced levels in [0, 1].  For a list: the list itself.

    Notes
    -----
    The string branch uses the module-level ``matplotlib`` import, so it
    needs matplotlib to be importable.
    """
    if not isinstance(mpl_cmap_name, str):
        assert isinstance(mpl_cmap_name, list)
        # Do not do any conversion if it's already a list
        return mpl_cmap_name

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9.  Prefer the
    # colormap registry where it exists and fall back to the old function on
    # older Matplotlib versions.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "get_cmap"):
        cmap = registry.get_cmap(mpl_cmap_name)
    else:
        cmap = matplotlib.cm.get_cmap(mpl_cmap_name)

    return [
        (level, convert_colormap_mpl_plotly(*cmap(level)))
        for level in np.linspace(0, 1, cmap.N)
    ]

184 

185 

def convert_lead_cmap_mpl_plotly(mpl_lead_cmap_init, mpl_lead_cmap_end,
                                 N=255):
    """Build a plotly colorscale interpolating linearly between two colors.

    Parameters
    ----------
    mpl_lead_cmap_init, mpl_lead_cmap_end : sequence
        Start and end color; entries 0-3 are read as red, green, blue and
        alpha.
    N : int
        Number of entries of the colorscale.

    Returns
    -------
    list
        `N` pairs ``(level, 'rgba(r,g,b,a)')`` with levels evenly spaced in
        [0, 1].  The color channels are multiplied by 255, alpha is not.
    """
    red, green, blue = (
        np.linspace(mpl_lead_cmap_init[channel],
                    mpl_lead_cmap_end[channel], N) * 255
        for channel in range(3)
    )
    alpha = np.linspace(mpl_lead_cmap_init[3], mpl_lead_cmap_end[3], N)
    positions = np.linspace(0, 1, N)

    return [
        (position, 'rgba({},{},{},{})'.format(r, g, b, a))
        for position, r, g, b, a in zip(positions, red, green, blue, alpha)
    ]

203 

204 

205if mpl_available: 205 ↛ 479line 205 didn't jump to line 479, because the condition on line 205 was never false

class LineCollection(collections.LineCollection):
    """Line collection whose line widths can be given in data space.

    When `reflen` is not None, the stored line widths are rescaled at draw
    time by the x-scale of the data transformation times `reflen`,
    converted with 72 points per inch and the figure dpi.
    """

    def __init__(self, segments, reflen=None, **kwargs):
        super().__init__(segments, **kwargs)
        self.reflen = reflen

    def set_linewidths(self, linewidths):
        # Only remember the requested widths; the values actually used are
        # computed in `draw`.
        self.linewidths_orig = nparray_if_array(linewidths)

    def draw(self, renderer):
        if self.reflen is None:
            scale = 1
        else:
            # Note: only works for aspect ratio 1!
            # 72.0 - there is 72 points in an inch
            x_scale = self.axes.transData.frozen().to_values()[0]
            scale = x_scale * 72.0 * self.reflen / self.figure.dpi

        super().set_linewidths(self.linewidths_orig * scale)
        return super().draw(renderer)

226 

227 

class PathCollection(collections.PathCollection):
    """Path collection whose symbol sizes can be given in data space.

    Each entry of `sizes` becomes a uniform scaling matrix returned by
    `get_transforms`.  When `reflen` is set, the paths are drawn with the
    data transformation (without its offset) scaled by `reflen`, and the
    line widths are rescaled accordingly at draw time.

    NOTE(review): `sizes` defaults to None but is iterated unconditionally
    in ``__init__``, so a value must in practice be supplied.
    """

    def __init__(self, paths, sizes=None, reflen=None, **kwargs):
        super().__init__(paths, sizes=sizes, **kwargs)

        self.reflen = reflen
        self.linewidths_orig = nparray_if_array(self.get_linewidths())

        make_affine = matplotlib.transforms.Affine2D
        self.transforms = np.array(
            [make_affine().scale(size).get_matrix() for size in sizes])

    def get_transforms(self):
        return self.transforms

    def get_transform(self):
        Affine2D = matplotlib.transforms.Affine2D
        if self.reflen is None:
            return Affine2D().scale(self.figure.dpi / 72.0)

        # For the paths, use the data transformation but strip the
        # offset (will be added later with offsets)
        linear_part = self.axes.transData.frozen().to_values()[:4]
        return Affine2D().from_values(*(linear_part + (0, 0))).scale(
            self.reflen)

    def draw(self, renderer):
        if self.reflen:
            # Note: only works for aspect ratio 1!
            x_scale = self.axes.transData.frozen().to_values()[0]
            factor = x_scale / self.figure.dpi * 72.0 * self.reflen
            self.set_linewidths(self.linewidths_orig * factor)

        # Deliberately calls the draw method of the Collection base class.
        return collections.Collection.draw(self, renderer)

260 

261 

262 if has3d: 262 ↛ 479line 262 didn't jump to line 479, because the condition on line 262 was never false

# Sorting is optional.
sort3d = True

# Compute the projection of a 3D length into 2D data coordinates
# for this we use 2 3D half-circles that are projected into 2D.
# (This gives the same length as projecting the full unit sphere.)

# 21 sample angles on [0, pi]; ``xyz`` holds the half-circle in the xy-plane
# with shape (3, 1, 21).
phi = np.linspace(0, pi, 21)
xyz = np.stack([np.cos(phi), np.sin(phi), 0 * phi]).reshape(-1, 1, 21)

# Shape (3, 42): the first 21 columns are the half-circle itself, the last
# 21 the same half-circle with its coordinates cyclically permuted.
unit_sphere = np.concatenate(
    [xyz[[0, 1, 2], 0], xyz[[2, 0, 1], 0]], axis=1)

278 

def projected_length(ax, length):
    """Return the largest 2D extent of a 3D length projected by `ax`.

    The sample points of ``unit_sphere`` are scaled by `length`, centred on
    the midpoint of the 3D axes limits and projected with the axes'
    projection matrix.  The result is the largest projected distance of a
    sample point from the projected centre.
    """
    limits = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
    center = limits.sum(axis=1) / 2.

    samples = unit_sphere * length + center.reshape(-1, 1)

    transform = mplot3d.proj3d.proj_transform
    projected = np.asarray(
        transform(samples[0], samples[1], samples[2], ax.get_proj())[:2])
    center[:2] = transform(
        center[0], center[1], center[2], ax.get_proj())[:2]

    # Broadcasting subtracts the projected centre from every sample column.
    deltas = projected - center[:2].reshape(-1, 1)
    return sqrt(np.sum(deltas**2, axis=0).max())

291 

292 

# Auxiliary array for calculating corners of a cube.
# ``np.float_`` was removed in NumPy 2.0; the builtin ``float`` selects the
# same float64 dtype on every NumPy version.
# NOTE(review): only ``corners[0]`` is populated; ``corners[1]`` and
# ``corners[2]`` stay all-zero.  Confirm that this is intended.
corners = np.zeros((3, 8, 6), float)
corners[0, [0, 1, 2, 3], 0] = corners[0, [4, 5, 6, 7], 1] = \
    corners[0, [0, 1, 4, 5], 2] = corners[0, [2, 3, 6, 7], 3] = \
    corners[0, [0, 2, 4, 6], 4] = corners[0, [1, 3, 5, 7], 5] = 1.0

298 

299 

class Line3DCollection(mplot3d.art3d.Line3DCollection):
    """3D line collection whose line widths can be given in data space.

    `zorder` is stored as ``zorder3d`` and its negative is reported by
    `do_3d_projection`.  When `reflen` is set, the stored line widths are
    rescaled at draw time using the projected length of `reflen`.
    """

    def __init__(self, segments, reflen=None, zorder=0, **kwargs):
        super().__init__(segments, **kwargs)
        self.reflen = reflen
        self.zorder3d = zorder

    def set_linewidths(self, linewidths):
        # Only remember the requested widths; the values actually used are
        # computed in `draw`.
        self.linewidths_orig = nparray_if_array(linewidths)

    def do_3d_projection(self, renderer):
        super().do_3d_projection(renderer)
        # The whole 3D ordering is flawed in mplot3d when several
        # collections are added. We just use normal zorder. Note the
        # "-" due to the different logic in the 3d plotting, we still
        # want larger zorder values to be plotted on top of smaller
        # ones.
        return -self.zorder3d

    def draw(self, renderer):
        scale = 1
        if self.reflen:
            proj_len = projected_length(self.axes, self.reflen)
            values = self.axes.transData.frozen().to_values()
            # Note: unlike in the 2D case, where we can enforce equal
            # aspect ratio, this (currently) does not work with
            # 3D plots in matplotlib. As an approximation, we
            # thus scale with the average of the x- and y-axis
            # transformation.
            scale = (proj_len * (values[0] + values[3]) * 0.5 * 72.0
                     / self.figure.dpi)

        super().set_linewidths(self.linewidths_orig * scale)
        super().draw(renderer)

335 

336 

class Path3DCollection(mplot3d.art3d.Patch3DCollection):
    """3D symbol collection whose sizes can be given in data space.

    Parameters
    ----------
    paths : sequence
        Paths of the symbols; each is wrapped in a ``PathPatch``.
    sizes : sequence
        One scale factor per symbol (or a single one); each becomes a
        uniform scaling matrix.
    reflen : number, optional
        Reference length.  When set, symbols and line widths are scaled by
        the projected length of `reflen`.
    zorder : number
        Stored as ``zorder3d`` and used to order collections.
    offsets : array, optional
        Symbol positions; columns 0 and 1 are passed to the base class as
        2D offsets, column 2 is used as the z coordinate.

    The ``*_orig`` attributes keep paths, line widths, colors, data array
    and transforms in their original (unsorted) order, so that every
    projection can re-sort them from scratch.
    """

    def __init__(self, paths, sizes, reflen=None, zorder=0,
                 offsets=None, **kwargs):
        paths = [matplotlib.patches.PathPatch(path) for path in paths]

        if offsets is not None:
            kwargs['offsets'] = offsets[:, :2]

        # Workaround for issue in Matplotlib-3.4.2 before PR merged
        # https://github.com/matplotlib/matplotlib/pull/20416
        self._z_markers_idx = slice(-1)

        super().__init__(paths, **kwargs)

        if offsets is not None:
            self.set_3d_properties(zs=offsets[:, 2], zdir="z")

        self.reflen = reflen
        self.zorder3d = zorder

        self.paths_orig = np.array(paths, dtype='object')
        self.linewidths_orig = nparray_if_array(self.get_linewidths())
        self.linewidths_orig2 = self.linewidths_orig
        self.array_orig = nparray_if_array(self.get_array())
        self.facecolors_orig = nparray_if_array(self.get_facecolors())
        self.edgecolors_orig = nparray_if_array(self.get_edgecolors())

        Affine2D = matplotlib.transforms.Affine2D
        self.orig_transforms = np.array(
            [Affine2D().scale(x).get_matrix() for x in sizes])
        self.transforms = self.orig_transforms

    def set_array(self, array):
        self.array_orig = nparray_if_array(array)
        super().set_array(array)

    def set_color(self, colors):
        self.facecolors_orig = nparray_if_array(colors)
        self.edgecolors_orig = self.facecolors_orig
        super().set_color(colors)

    def set_edgecolors(self, colors):
        colors = matplotlib.colors.colorConverter.to_rgba_array(colors)
        self.edgecolors_orig = nparray_if_array(colors)
        super().set_edgecolors(colors)

    def get_transforms(self):
        # this is exact only for an isometric projection, for the
        # perspective projection used in mplot3d it's an approximation
        return self.transforms

    def get_transform(self):
        Affine2D = matplotlib.transforms.Affine2D
        if self.reflen:
            proj_len = projected_length(self.axes, self.reflen)

            # For the paths, use the data transformation but strip the
            # offset (will be added later with the offsets).
            args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
            return Affine2D().from_values(*args).scale(proj_len)
        else:
            return Affine2D().scale(self.figure.dpi / 72.0)

    def do_3d_projection(self, renderer):
        xs, ys, zs = self._offsets3d

        # numpy complains about zero-length index arrays
        if len(xs) == 0:
            return -self.zorder3d

        proj = mplot3d.proj3d.proj_transform_clip
        vs = np.array(proj(xs, ys, zs, renderer.M)[:3])

        if sort3d:
            # Order of decreasing projected z coordinate.
            indx = vs[2].argsort()[::-1]

            self.set_offsets(vs[:2, indx].T)

            if len(self.paths_orig) > 1:
                paths = np.resize(self.paths_orig, (vs.shape[1],))
                self.set_paths(paths[indx])

            if len(self.orig_transforms) > 1:
                # Always sort from the original order.  Indexing the
                # previously sorted ``self.transforms`` would compound the
                # permutations of successive draws and pair symbols with
                # the wrong sizes.
                self.transforms = self.orig_transforms[indx]

            lw_orig = self.linewidths_orig
            if isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1:
                self.linewidths_orig2 = np.resize(lw_orig,
                                                  (vs.shape[1],))[indx]

            # Note: here array, facecolors and edgecolors are
            # guaranteed to be 2d numpy arrays or None.  (And
            # array is the same length as the coordinates)

            # The base class setters are called via super() so that the
            # ``*_orig`` copies recorded by the overrides stay unsorted.
            if self.array_orig is not None:
                super().set_array(self.array_orig[indx])

            if (self.facecolors_orig is not None and
                    self.facecolors_orig.shape[0] > 1):
                shape = list(self.facecolors_orig.shape)
                shape[0] = vs.shape[1]
                super().set_facecolors(
                    np.resize(self.facecolors_orig, shape)[indx])

            if (self.edgecolors_orig is not None and
                    self.edgecolors_orig.shape[0] > 1):
                shape = list(self.edgecolors_orig.shape)
                shape[0] = vs.shape[1]
                super().set_edgecolors(
                    np.resize(self.edgecolors_orig, shape)[indx])
        else:
            self.set_offsets(vs[:2].T)

        # the whole 3D ordering is flawed in mplot3d when several
        # collections are added. We just use normal zorder, but correct
        # by the projected z-coord of the "center of gravity",
        # normalized by the projected z-coord of the world coordinates.
        # In doing so, several Path3DCollections are plotted probably
        # in the right order (it's not exact) if they have the same
        # zorder. Still, smaller and larger integer zorders are plotted
        # below or on top.

        bbox = np.asarray(self.axes.get_w_lims())

        proj = mplot3d.proj3d.proj_transform_clip
        cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2]

        # ``np.ptp`` instead of the ``ndarray.ptp`` method, which was
        # removed in NumPy 2.0.
        return -self.zorder3d + vs[2].mean() / np.ptp(cz)

    def draw(self, renderer):
        if self.reflen:
            proj_len = projected_length(self.axes, self.reflen)
            args = self.axes.transData.frozen().to_values()
            # Scale with the average of the x- and y-axis transformation.
            factor = proj_len * (args[0] +
                                 args[3]) * 0.5 * 72.0 / self.figure.dpi

            self.set_linewidths(self.linewidths_orig2 * factor)

        super().draw(renderer)

478 

479if plotly_available: 479 ↛ exitline 479 didn't exit the module, because the condition on line 479 was never false

def matplotlib_to_plotly_cmap(cmap, pl_entries):
    """Sample a matplotlib-style colormap into a plotly colorscale.

    Parameters
    ----------
    cmap : callable
        Maps a level in [0, 1] to a color sequence whose first three
        entries are red, green and blue in [0, 1].
    pl_entries : int
        Number of entries of the colorscale.

    Returns
    -------
    list
        `pl_entries` items ``[level, 'rgb(r, g, b)']`` with evenly spaced
        levels and channels truncated to 8-bit integers.
    """
    # A single entry has no spacing; avoid dividing by zero in that case.
    step = 1.0 / (pl_entries - 1) if pl_entries > 1 else 0.0
    pl_colorscale = []

    for k in range(pl_entries):
        level = k * step
        # ``map`` returns an iterator in Python 3 and cannot be indexed, so
        # convert the channels as an array.  Plain ints keep the string
        # independent of how NumPy prints its scalars.
        channels = (np.array(cmap(level)[:3]) * 255).astype(np.uint8)
        red, green, blue = (int(channel) for channel in channels)
        pl_colorscale.append(
            [level, 'rgb({}, {}, {})'.format(red, green, blue)])

    return pl_colorscale