# -*- coding: utf-8 -*-

# gms_preprocessing, spatial and spectral homogenization of satellite remote sensing data
#
# Copyright (C) 2020 Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
#
# This software was developed within the context of the GeoMultiSens project funded
# by the German Federal Ministry of Education and Research
# (project grant code: 01 IS 14 010 A-C).
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later version.
# Please note the following exception: `gms_preprocessing` depends on tqdm, which
# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files
# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore".
# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program. If not, see <http://www.gnu.org/licenses/>.

"""Collection of helper functions for GeoMultiSens."""

import collections
import errno
import gzip
from zipfile import ZipFile
import itertools
import math
import operator
import os
import re
import shlex
import warnings
from datetime import datetime

import numpy as np
import psycopg2
import shapely
import shapely.wkb  # imported explicitly because shapely.wkb.loads() is used below
from shapely.geometry import Polygon

try:
    from osgeo import ogr
except ImportError:
    import ogr
from multiprocessing import sharedctypes
from matplotlib import pyplot as plt
from subprocess import Popen, PIPE
from xml.etree.ElementTree import QName

from ..options.config import GMS_config as CFG
from . import database_tools as DB_T
from ..misc.definition_dicts import proc_chain

from py_tools_ds.geo.coord_trafo import mapXY2imXY, reproject_shapelyGeometry
from py_tools_ds.geo.coord_calc import corner_coord_to_minmax

__author__ = 'Daniel Scheffler'

def get_parentObjDict():
    from ..algorithms.L1A_P import L1A_object
    from ..algorithms.L1B_P import L1B_object
    from ..algorithms.L1C_P import L1C_object
    from ..algorithms.L2A_P import L2A_object
    from ..algorithms.L2B_P import L2B_object
    from ..algorithms.L2C_P import L2C_object

    return dict(L1A=L1A_object,
                L1B=L1B_object,
                L1C=L1C_object,
                L2A=L2A_object,
                L2B=L2B_object,
                L2C=L2C_object)


initArgsDict = {'L1A': (None,), 'L1B': (None,), 'L1C': (None,),
                'L2A': (None,), 'L2B': (None,), 'L2C': (None,)}
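# Illustrative usage of get_parentObjDict() and initArgsDict (sketch only; what the GMS object
# constructors actually require is defined in the ..algorithms modules imported above):
#
#   ParentObjClass = get_parentObjDict()['L1B']       # -> L1B_object
#   dummy_obj = ParentObjClass(*initArgsDict['L1B'])  # -> equivalent to L1B_object(None)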

def silentremove(filename):
    # type: (str) -> None
    """Remove the given file without raising OSError exceptions, e.g. if the file does not exist."""

    try:
        os.remove(filename)
    except OSError as e:
        if e.errno != errno.ENOENT:  # errno.ENOENT = no such file or directory
            raise  # re-raise the exception if a different error occurred


def silentmkdir(path_dir_file):
    # type: (str) -> None
    """Create the directory containing the given file or directory path if it does not exist yet,
    without raising an error if it already exists."""
    while not os.path.isdir(os.path.dirname(path_dir_file)):
        try:
            os.makedirs(os.path.dirname(path_dir_file))
        except OSError as e:
            if e.errno != errno.EEXIST:  # 17: directory already exists (e.g., created by a concurrent process)
                raise
            else:
                pass
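# Illustrative usage of silentremove() and silentmkdir() (sketch; the paths are hypothetical):
#
#   silentmkdir('/tmp/gms_output/subdir/result.bsq')   # creates /tmp/gms_output/subdir if needed
#   silentremove('/tmp/gms_output/subdir/result.bsq')  # no error if the file does not exist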

def gzipfile(iname, oname, compression_level=1, blocksize=None):
    blocksize = blocksize if blocksize else 1 << 16  # 64kB
    with open(iname, 'rb') as f_in, gzip.open(oname, 'wb', compression_level) as f_out:
        while True:
            block = f_in.read(blocksize)
            if not block:  # f_in.read() returns bytes, so test for an empty block instead of comparing to ''
                break
            f_out.write(block)


def get_zipfile_namelist(path_zipfile):
    with ZipFile(path_zipfile) as zF:
        namelist = zF.namelist()
    return namelist


def ENVIfile_to_ENVIcompressed(inPath_hdr, outPath_hdr=None):
    inPath_bsq = os.path.splitext(inPath_hdr)[0] + '.bsq'
    outPath_bsq = os.path.splitext(outPath_hdr)[0] + '.bsq' if outPath_hdr else inPath_bsq
    gzipfile(inPath_bsq, outPath_bsq)
    with open(inPath_hdr, 'r') as inF:
        items = inF.read().split('\n')  # FIXME use append write mode
    items.append('file compression = 1')
    with open(inPath_hdr, 'w') as outFile:
        [outFile.write(item + '\n') for item in items]
    # FIXME include file reordering

def subcall_with_output(cmd, no_stdout=False, no_stderr=False):
    """Execute an external command and return its stdout, exit code and stderr.

    :param cmd: a normal shell command including parameters
    :return: (stdout, exitcode, stderr)
    """

    proc = Popen(shlex.split(cmd), stdout=None if no_stdout else PIPE, stderr=None if no_stderr else PIPE)
    out, err = proc.communicate()
    exitcode = proc.returncode

    return out, exitcode, err
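# Illustrative usage of subcall_with_output() (sketch; the command is only an example and
# stdout/stderr are returned as bytes objects):
#
#   out, exitcode, err = subcall_with_output('gdalinfo --version')
#   if exitcode != 0:
#       raise RuntimeError(err.decode('latin-1'))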

def sorted_nicely(iterable):
    """Sort the given iterable in the way that humans expect (natural/alphanumeric sort order).

    :param iterable: iterable of strings to be sorted
    """

    def convert(text): return int(text) if text.isdigit() else text

    def alphanum_key(key): return [convert(c) for c in re.split('([0-9]+)', key)]

    return sorted(iterable, key=alphanum_key)
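# Illustrative example for sorted_nicely() vs. plain sorted():
#
#   sorted(['band10', 'band2', 'band1'])         # -> ['band1', 'band10', 'band2']
#   sorted_nicely(['band10', 'band2', 'band1'])  # -> ['band1', 'band2', 'band10']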

def safe_str(obj):
    """Return a safe string that will not cause any UnicodeEncodeError issues."""
    return obj.encode('ascii', 'ignore').decode('ascii')


def is_proc_level_lower(current_lvl, target_lvl):
    # type: (str, str) -> bool
    """Return True if current_lvl is lower than target_lvl.

    :param current_lvl: current processing level (to be tested)
    :param target_lvl:  target processing level (reference)
    """
    return current_lvl is None or proc_chain.index(current_lvl) < proc_chain.index(target_lvl)
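# Illustrative example for is_proc_level_lower(), assuming proc_chain lists the GMS processing
# levels in ascending order (L1A ... L2C), as suggested by get_parentObjDict() above:
#
#   is_proc_level_lower('L1B', 'L2A')  # -> True
#   is_proc_level_lower('L2C', 'L2A')  # -> False
#   is_proc_level_lower(None, 'L1A')   # -> True (no processing level reached yet)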

def convert_absPathArchive_to_GDALvsiPath(path_archive):
    assert path_archive.endswith(".zip") or path_archive.endswith(".tar") or path_archive.endswith(".tar.gz") or \
        path_archive.endswith(".tgz"), """*%s archives are not yet supported. Please provide .zip, .tar, .tar.gz or
        .tgz archives.""" % os.path.splitext(path_archive)[1]
    gdal_prefix_dict = {'.zip': '/vsizip', '.tar': '/vsitar', '.tar.gz': '/vsitar', '.tgz': '/vsitar',
                        '.gz': '/vsigzip'}
    file_suffix = os.path.splitext(path_archive)[1]
    file_suffix = '.tar.gz' if path_archive.endswith('.tar.gz') else file_suffix
    return os.path.join(gdal_prefix_dict[file_suffix], os.path.basename(path_archive))
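# Illustrative example for convert_absPathArchive_to_GDALvsiPath() (the archive name is hypothetical):
#
#   convert_absPathArchive_to_GDALvsiPath('/data/archive/LE71930242015276NSG00.tar.gz')
#   # -> '/vsitar/LE71930242015276NSG00.tar.gz'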

class mp_SharedNdarray(object):
    """
    Wrapper class which collects all instances necessary to make a numpy ndarray accessible as
    shared memory when using multiprocessing. It exposes the numpy array via three different
    views which can be used to access it globally.

    _init provides the mechanism to make this array available in each worker; best used via the
    provided mp_initializer.
    """

    def __init__(self, dims):
        """
        :param dims: tuple of dimensions which is used to instantiate an ndarray using np.empty
        """
        # self.ct = np.ctypeslib.as_ctypes(np.zeros(dims, dtype=float))  # ctypes view on the new array
        self.ct = np.ctypeslib.as_ctypes(np.empty(dims, dtype=float))  # ctypes view on the new array
        self.sh = sharedctypes.Array(self.ct._type_, self.ct, lock=False)  # shared memory view on the array
        self.np = np.ctypeslib.as_array(self.sh)  # numpy view on the array

    def _init(self, globals, name):
        """
        Add a numpy view on this array to the given globals dict under the given name, using the
        shared-memory view ([mp_SharedNdarray instance].sh), so that the array becomes globally
        available (e.g., within a multiprocessing worker).
        """
        globals[name] = np.ctypeslib.as_array(self.sh)

def mp_initializer(globals, globs):
    """
    Add the given name/value pairs to the given globals dictionary. globs shall be a dict of
    name:value pairs; each value is added to globals under its name. If a value provides an
    _init attribute, that one is called instead (see mp_SharedNdarray._init).

    This makes most sense when used as the initializer of a multiprocessing pool, e.g.:
    Pool(initializer=mp_initializer, initargs=(globals(), globs))

    :param globals: dictionary of global variables to be updated, e.g. globals()
    :param globs:   dictionary of name:value pairs to be added to globals
    """

    for name, value in globs.items():
        try:
            value._init(globals, name)
        except AttributeError:
            globals[name] = value
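# Minimal sketch of how mp_SharedNdarray and mp_initializer are intended to work together with a
# multiprocessing Pool (the name 'shared_np_view' is hypothetical; the pattern relies on fork-based
# process start so that the parent globals dict is inherited by the workers):
#
#   from multiprocessing import Pool
#
#   shared = mp_SharedNdarray((1024, 1024))
#   shared.np[:] = 0  # fill the shared array via its numpy view in the parent process
#   pool = Pool(initializer=mp_initializer, initargs=(globals(), {'shared_np_view': shared}))
#   # within each worker, 'shared_np_view' is then available as a global numpy view on the shared memory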

def group_objects_by_attributes(object_list, *attributes):
    get_attr = operator.attrgetter(*attributes)
    return [list(g) for k, g in itertools.groupby(sorted(object_list, key=get_attr), get_attr)]


def group_tuples_by_keys_of_tupleElements(tuple_list, tupleElement_index, key):
    unique_vals = set([tup[tupleElement_index][key] for tup in tuple_list])
    groups = []
    for val in unique_vals:
        groups.append([tup for tup in tuple_list if tup[tupleElement_index][key] == val])
    return groups


def group_dicts_by_key(dict_list, key):
    unique_vals = set([dic[key] for dic in dict_list])
    groups = [[dic for dic in dict_list if dic[key] == val] for val in unique_vals]
    return groups
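# Illustrative example for group_dicts_by_key() (sketch; group order depends on set iteration order):
#
#   group_dicts_by_key([{'proc_level': 'L1A', 'scene': 1},
#                       {'proc_level': 'L1B', 'scene': 2},
#                       {'proc_level': 'L1A', 'scene': 3}], key='proc_level')
#   # -> [[{'proc_level': 'L1A', 'scene': 1}, {'proc_level': 'L1A', 'scene': 3}],
#   #     [{'proc_level': 'L1B', 'scene': 2}]]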

def cornerLonLat_to_postgreSQL_poly(CornerLonLat):
    """Converts a coordinate list [UL_LonLat, UR_LonLat, LL_LonLat, LR_LonLat] to a postgreSQL polygon (WKT string).

    :param CornerLonLat: list of XY-coordinate tuples
    """

    return str(Polygon(CornerLonLat))


def postgreSQL_poly_to_cornerLonLat(pGSQL_poly):
    # type: (str) -> list
    """Converts a postgreSQL polygon (WKT string) to a coordinate list [UL_LonLat, UR_LonLat, LL_LonLat, LR_LonLat].

    :param pGSQL_poly: WKT polygon string, e.g. 'POLYGON((<x> <y>, ...))'
    """

    if not pGSQL_poly.startswith('POLYGON'):
        raise ValueError("'pGSQL_poly' has to start with 'POLYGON...'. Got %s" % pGSQL_poly)
    fl = [float(i) for i in re.findall(r"[-+]?\d*\.\d+|\d+", pGSQL_poly)]
    CornerLonLat = [(fl[4], fl[5]), (fl[6], fl[7]), (fl[2], fl[3]), (fl[0], fl[1])]  # UL,UR,LL,LR
    return CornerLonLat
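# Illustrative example for the two WKT helpers above (coordinates are hypothetical; the exact number
# formatting of the WKT string depends on the shapely version):
#
#   cornerLonLat_to_postgreSQL_poly([(10., 50.), (11., 50.), (10., 49.), (11., 49.)])
#   # -> 'POLYGON ((10 50, 11 50, 10 49, 11 49, 10 50))'
#
#   postgreSQL_poly_to_cornerLonLat('POLYGON((10 50, 11 50, 11 49, 10 49, 10 50))')
#   # -> [(11.0, 49.0), (10.0, 49.0), (11.0, 50.0), (10.0, 50.0)]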

def postgreSQL_geometry_to_postgreSQL_poly(geom):
    # type: (str) -> str
    connection = psycopg2.connect(CFG.conn_database)
    if connection is None:
        return 'database connection fault'
    cursor = connection.cursor()
    cursor.execute("SELECT ST_AsText('%s')" % geom)
    pGSQL_poly = cursor.fetchone()[0]
    cursor.close()
    connection.close()
    return pGSQL_poly


def postgreSQL_geometry_to_shapelyPolygon(wkb_hex):
    return shapely.wkb.loads(wkb_hex, hex=True)


def shapelyPolygon_to_postgreSQL_geometry(shapelyPoly):
    # type: (Polygon) -> str
    return shapelyPoly.wkb_hex  # same result as "SELECT ST_GeomFromText('%s')" % shapelyPoly

def get_imageCoords_from_shapelyPoly(shapelyPoly, im_gt):
    # type: (Polygon, list) -> list
    """Converts each vertex coordinate of a shapely polygon into image coordinates corresponding to the given
    geotransform, without checking whether the resulting image coordinates are valid. Invalid coordinates must be
    filtered afterwards.

    :param shapelyPoly: <shapely.Polygon>
    :param im_gt:       <list> the GDAL geotransform of the target image
    """

    def get_coordsArr(shpPoly): return np.swapaxes(np.array(shpPoly.exterior.coords.xy), 0, 1)

    coordsArr = get_coordsArr(shapelyPoly)
    imCoordsXY = [mapXY2imXY((X, Y), im_gt) for X, Y in coordsArr.tolist()]
    return imCoordsXY

def get_valid_arrSubsetBounds(arr_shape, tgt_bounds, buffer=0):
    # type: (tuple, tuple, float) -> tuple
    """Validates a given tuple of image coordinates by checking whether each coordinate is within the given bounding
    box and replacing invalid coordinates with valid ones. This function is needed in connection with
    get_arrSubsetBounds_from_shapelyPolyLonLat().

    :param arr_shape:   <tuple of ints> the dimensions of the bounding box within which the target coordinates are
                        validated -> (rows, cols, bands) or (rows, cols)
    :param tgt_bounds:  <tuple of floats> the target image coordinates in the form (xmin, xmax, ymin, ymax)
    :param buffer:      <float> an optional buffer size (image pixel units)
    """

    rows, cols = arr_shape[:2]
    xmin, xmax, ymin, ymax = tgt_bounds
    if buffer:
        xmin, xmax, ymin, ymax = xmin - buffer, xmax + buffer, ymin - buffer, ymax + buffer

    xmin = int(xmin) if int(xmin) >= 0 else 0
    xmax = math.ceil(xmax) if math.ceil(xmax) <= cols - 1 else cols - 1
    ymin = int(ymin) if int(ymin) >= 0 else 0
    ymax = math.ceil(ymax) if math.ceil(ymax) <= rows - 1 else rows - 1

    outbounds = xmin, xmax, ymin, ymax
    return outbounds if (xmax > 0 or xmax < cols) and (ymax > 0 or ymax < rows) else None
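# Illustrative example for get_valid_arrSubsetBounds(): bounds partly outside a 1000 x 2000 pixel
# image are clipped to the valid image area.
#
#   get_valid_arrSubsetBounds((1000, 2000), (-10.3, 500.2, 990.7, 1200.9))
#   # -> (0, 501, 990, 999)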

def get_arrSubsetBounds_from_shapelyPolyLonLat(arr_shape, shpPolyLonLat, im_gt, im_prj, pixbuffer=0,
                                               ensure_valid_coords=True):
    # type: (tuple, Polygon, list, str, float, bool) -> tuple
    """Returns validated image coordinates corresponding to the given shapely polygon. This function can be used to
    get the image coordinates of e.g. MGRS tiles for a specific target image.

    :param arr_shape:           <tuple of ints> the dimensions of the target image -> (rows, cols, bands) or
                                (rows, cols)
    :param shpPolyLonLat:       <shapely.Polygon> the shapely polygon to get image coordinates for
    :param im_gt:               <tuple> GDAL geotransform of the target image
    :param im_prj:              <str> GDAL geographic projection (WKT string) of the target image
                                (automatic reprojection is done if necessary)
    :param pixbuffer:           <float> an optional buffer size (image pixel units)
    :param ensure_valid_coords: <bool> whether to ensure that the returned values are all inside the original
                                image bounding box
    """

    shpPolyImPrj = reproject_shapelyGeometry(shpPolyLonLat, 4326, im_prj)
    imCoordsXY = get_imageCoords_from_shapelyPoly(shpPolyImPrj, im_gt)
    bounds = corner_coord_to_minmax(imCoordsXY)
    outbounds = get_valid_arrSubsetBounds(arr_shape, bounds, buffer=pixbuffer) if ensure_valid_coords else bounds
    if outbounds:
        xmin, xmax, ymin, ymax = outbounds
        return xmin, xmax, ymin, ymax
    else:
        return None

def get_UL_LR_from_shapefile_features(path_shp):
    # type: (str) -> list
    """Returns a list of upper-left/lower-right coordinates ((ul, lr) tuples) for all features of a given shapefile.

    :param path_shp: <str> the path of the shapefile
    """

    dataSource = ogr.Open(path_shp)
    layer = dataSource.GetLayer(0)
    ullr_list = []
    for feature in layer:
        e = feature.geometry().GetEnvelope()
        ul = e[0], e[3]
        lr = e[1], e[2]
        ullr_list.append((ul, lr))
    del dataSource, layer
    return ullr_list

def reorder_CornerLonLat(CornerLonLat):
    """Reorders corner coordinate lists from [UL, UR, LL, LR] to clockwise order: [UL, UR, LR, LL]."""

    if len(CornerLonLat) > 4:
        warnings.warn('Only 4 of the given %s corner coordinates were respected.' % len(CornerLonLat))
    return [CornerLonLat[0], CornerLonLat[1], CornerLonLat[3], CornerLonLat[2]]

def sceneID_to_trueDataCornerLonLat(scene_ID):
    """Returns a list of corner coordinates ordered like (UL, UR, LL, LR) corresponding to the given scene_ID by
    querying the database geometry field."""

    try:
        pgSQL_geom = DB_T.get_info_from_postgreSQLdb(CFG.conn_database, 'scenes_proc', 'bounds',
                                                     {'sceneid': scene_ID})[0][0]
    except IndexError:
        pgSQL_geom = DB_T.get_info_from_postgreSQLdb(CFG.conn_database, 'scenes', 'bounds', {'id': scene_ID})[0][0]

    assert shapely.wkb.loads(pgSQL_geom, hex=True).is_valid, \
        'Database error: Received an invalid geometry from the postgreSQL database!'
    return postgreSQL_poly_to_cornerLonLat(postgreSQL_geometry_to_postgreSQL_poly(pgSQL_geom))

def scene_ID_to_shapelyPolygon(scene_ID):
    # type: (int) -> Polygon
    """
    Returns a LonLat shapely.Polygon() object corresponding to the given scene_ID.
    """
    poly = Polygon(reorder_CornerLonLat(sceneID_to_trueDataCornerLonLat(scene_ID)))
    if not poly.is_valid:
        poly = poly.buffer(0)
        assert poly.is_valid
    return poly


def CornerLonLat_to_shapelyPoly(CornerLonLat):
    """Returns a shapely.Polygon() object based on the given coordinate list."""
    poly = Polygon(reorder_CornerLonLat(CornerLonLat))
    if not poly.is_valid:
        poly = poly.buffer(0)
        assert poly.is_valid
    return poly
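# Illustrative example for reorder_CornerLonLat() and CornerLonLat_to_shapelyPoly() (coordinates are
# hypothetical):
#
#   corners = [(10., 50.), (11., 50.), (10., 49.), (11., 49.)]  # UL, UR, LL, LR
#   reorder_CornerLonLat(corners)
#   # -> [(10.0, 50.0), (11.0, 50.0), (11.0, 49.0), (10.0, 49.0)]  # UL, UR, LR, LL
#   CornerLonLat_to_shapelyPoly(corners).bounds
#   # -> (10.0, 49.0, 11.0, 50.0)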

def find_in_xml_root(namespace, xml_root, branch, *branches, findall=None):
    """
    Sentinel-2 XML helper function to get a part of the XML tree, searching from the root element.

    :param namespace: XML namespace; applied (only) to the first branch
    :param xml_root:  root XML element to search from
    :param branch:    first branch; is combined with the namespace
    :param branches:  repeated find() calls are performed along these tags
    :param findall:   if given, a final findall() with this tag is performed
    :return: found XML object, None if nothing was found
    """

    buf = xml_root.find(str(QName(namespace, branch)))
    for br in branches:
        buf = buf.find(br)
    if findall is not None:
        buf = buf.findall(findall)
    return buf


def find_in_xml(xml, *branch):
    """
    Sentinel-2 XML helper function.

    :param xml:    XML object to search in
    :param branch: iterate to branches using find()
    :return: XML object, None if nothing was found
    """

    buf = xml
    for br in branch:
        buf = buf.find(br)
    return buf
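# Minimal sketch of how the two XML helpers above could be used for Sentinel-2 metadata (file name,
# namespace URI and tag names below are made-up placeholders, not taken from an actual product):
#
#   from xml.etree import ElementTree
#
#   root = ElementTree.parse('S2_tile_metadata.xml').getroot()
#   ns = 'https://example.com/S2_PSD_schema'
#   geom_info = find_in_xml_root(ns, root, 'Geometric_Info', 'Tile_Angles')
#   sun_zenith = find_in_xml(geom_info, 'Sun_Angles_Grid', 'Zenith')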

def get_values_from_xml(leaf, dtype=float):
    """
    Sentinel-2 XML helper function.

    :param leaf:  XML object which is searched for VALUES tags, which are then composed into a numpy array
    :param dtype: dtype of the returned numpy array
    :return: numpy array
    """

    return np.array([ele.text.split(" ") for ele in leaf.findall("VALUES")], dtype=dtype)


def stack_detectors(inp):
    warnings.filterwarnings(action='ignore', message=r'Mean of empty slice')
    res = {bandId: np.nanmean(np.dstack(tuple(inp[bandId].values())), axis=2) for bandId in inp}
    warnings.filterwarnings(action='default', message=r'Mean of empty slice')
    return res
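# Illustrative input/output structure for stack_detectors() (sketch; 'arr_det1' and 'arr_det2' stand
# for 2D numpy arrays of equal shape with np.nan where a detector has no coverage):
#
#   inp = {'B02': {'det1': arr_det1, 'det2': arr_det2},
#          'B03': {'det1': arr_det1, 'det2': arr_det2}}
#   stacked = stack_detectors(inp)  # -> {'B02': <detector-averaged 2D array>, 'B03': ...}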

class Landsat_entityID_decrypter(object):
    SenDict = {'C8': 'OLI_TIRS', 'O8': 'OLI', 'T8': 'TIRS', 'E7': 'ETM+', 'T5': 'TM', 'T4': 'TM', 'M1': 'MSS1'}
    SatDict = {'C8': 'Landsat-8', 'O8': 'Landsat-8', 'T8': 'Landsat-8',
               'E7': 'Landsat-7', 'T5': 'Landsat-5', 'T4': 'Landsat-4', 'M1': 'Landsat-1'}

    def __init__(self, entityID):
        self.entityID = entityID
        LDict = self.decrypt()

        SatSen = LDict['sensor'] + LDict['satellite']
        self.satellite = self.SatDict[SatSen]
        self.sensor = self.SenDict[SatSen]
        self.WRS_path = int(LDict['WRS_path'])
        self.WRS_row = int(LDict['WRS_row'])
        self.AcqDate = datetime.strptime(LDict['year'] + LDict['julian_day'], '%Y%j')
        if self.sensor == 'ETM+':
            self.SLCOnOff = 'SLC_ON' if self.AcqDate <= datetime.strptime('2003-05-31 23:46:34', '%Y-%m-%d %H:%M:%S') \
                else 'SLC_OFF'
            self.sensorIncSLC = '%s_%s' % (self.sensor, self.SLCOnOff)
        self.ground_station_ID = LDict['ground_station_identifier']
        self.archive_ver = LDict['archive_version_number']

    def decrypt(self):
        """Decompose the entity ID into its components (format: LXSPPPRRRYYYYDDDGSIVV)."""
        LDict = collections.OrderedDict()
        LDict['sensor'] = self.entityID[1]
        LDict['satellite'] = self.entityID[2]
        LDict['WRS_path'] = self.entityID[3:6]
        LDict['WRS_row'] = self.entityID[6:9]
        LDict['year'] = self.entityID[9:13]
        LDict['julian_day'] = self.entityID[13:16]
        LDict['ground_station_identifier'] = self.entityID[16:19]
        LDict['archive_version_number'] = self.entityID[19:21]
        return LDict
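# Illustrative example for Landsat_entityID_decrypter (the entity ID below is a made-up but
# format-conform Landsat-8 ID following the LXSPPPRRRYYYYDDDGSIVV scheme):
#
#   dec = Landsat_entityID_decrypter('LC81930242015276LGN00')
#   dec.satellite, dec.sensor    # -> ('Landsat-8', 'OLI_TIRS')
#   dec.WRS_path, dec.WRS_row    # -> (193, 24)
#   dec.AcqDate                  # -> datetime(2015, 10, 3), i.e. day of year 276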

def subplot_2dline(XY_tuples, titles=None, shapetuple=None, grid=False):
    shapetuple = (1, len(XY_tuples)) if shapetuple is None else shapetuple
    assert titles is None or len(titles) == len(
        XY_tuples), 'List in titles keyword must have the same length as the passed XY_tuples.'
    fig = plt.figure(figsize=plt.figaspect(.5 if shapetuple[1] >= shapetuple[0] else 2.) *
                     (2 if len(XY_tuples) == 1 else 3))
    for i, XY in enumerate(XY_tuples):
        ax = fig.add_subplot(shapetuple[0], shapetuple[1], i + 1)
        X, Y = XY
        ax.plot(X, Y, linestyle='-')
        if titles:
            ax.set_title(titles[i])
        if grid:
            ax.grid(which='major', axis='both', linestyle='-')
    plt.show()


def subplot_imshow(ims, titles=None, shapetuple=None, grid=False):
    ims = [ims] if not isinstance(ims, list) else ims
    assert titles is None or len(titles) == len(ims), 'Error: Got more or less titles than images.'
    shapetuple = (1, len(ims)) if shapetuple is None else shapetuple
    fig, axes = plt.subplots(shapetuple[0], shapetuple[1],
                             figsize=plt.figaspect(.5 if shapetuple[1] > shapetuple[0] else 2.) * 3)
    [axes[i].imshow(im, cmap='binary', interpolation='none', vmin=np.percentile(im, 2), vmax=np.percentile(im, 98))
     for i, im in enumerate(ims)]
    if titles:
        [axes[i].set_title(titles[i]) for i in range(len(ims))]
    if grid:
        [axes[i].grid(which='major', axis='both', linestyle='-') for i in range(len(ims))]
    plt.show()


def subplot_3dsurface(ims, shapetuple=None):
    ims = [ims] if not isinstance(ims, list) else ims
    shapetuple = (1, len(ims)) if shapetuple is None else shapetuple
    fig = plt.figure(figsize=plt.figaspect(.5 if shapetuple[1] >= shapetuple[0] else 2.) * 3)
    for i, im in enumerate(ims):
        ax = fig.add_subplot(shapetuple[0], shapetuple[1], i + 1, projection='3d')
        x = np.arange(0, im.shape[0], 1)
        y = np.arange(0, im.shape[1], 1)
        X, Y = np.meshgrid(x, y)
        Z = im.reshape(X.shape)
        ax.plot_surface(X, Y, Z, cmap=plt.cm.hot)
        ax.contour(X, Y, Z, zdir='x', cmap=plt.cm.coolwarm, offset=0)
        ax.contour(X, Y, Z, zdir='y', cmap=plt.cm.coolwarm, offset=im.shape[1])
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
    plt.show()