Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2 

3# spechomo, Spectral homogenization of multispectral satellite data 

4# 

5# Copyright (C) 2019-2021 

6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de) 

7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam, 

8# Germany (https://www.gfz-potsdam.de/) 

9# 

10# This software was developed within the context of the GeoMultiSens project funded 

11# by the German Federal Ministry of Education and Research 

12# (project grant code: 01 IS 14 010 A-C). 

13# 

14# Licensed under the Apache License, Version 2.0 (the "License"); 

15# you may not use this file except in compliance with the License. 

16# You may obtain a copy of the License at 

17# 

18# http://www.apache.org/licenses/LICENSE-2.0 

19# 

20# Please note the following exception: `spechomo` depends on tqdm, which is 

21# distributed under the Mozilla Public Licence (MPL) v2.0 except for the files 

22# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore". 

23# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE. 

24# 

25# Unless required by applicable law or agreed to in writing, software 

26# distributed under the License is distributed on an "AS IS" BASIS, 

27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

28# See the License for the specific language governing permissions and 

29# limitations under the License. 

30 

31import json 

32import os 

33import re 

34from collections import OrderedDict 

35from typing import Union, List # noqa F401 # flake8 issue 

36from tqdm import tqdm 

37 

38import numpy as np 

39from geoarray import GeoArray 

40from pandas import DataFrame 

41from pandas.plotting import scatter_matrix 

42from pyrsr import RSR 

43 

44from .utils import im2spectra 

45 

46 

class TrainingData(object):
    """Class for analyzing statistical relations between a pair of machine learning training data cubes."""

    def __init__(self, im_X, im_Y, test_size):
        # type: (Union[GeoArray, np.ndarray], Union[GeoArray, np.ndarray], Union[float, int]) -> None
        """Get instance of TrainingData.

        :param im_X:        input image X
        :param im_Y:        input image Y
        :param test_size:   test size (proportion as float between 0 and 1) or absolute value as integer
        """
        from sklearn.model_selection import train_test_split  # avoids static TLS error here

        self.im_X = GeoArray(im_X)
        self.im_Y = GeoArray(im_Y)

        # Set spectra (3D to 2D conversion)
        self.spectra_X = im2spectra(self.im_X)
        self.spectra_Y = im2spectra(self.im_Y)

        # Set train and test variables
        # NOTE: If random_state is set to an Integer, train_test_split will always select the same 'pseudo-random' set
        #       of the input data.
        self.train_X, self.test_X, self.train_Y, self.test_Y = \
            train_test_split(self.spectra_X, self.spectra_Y, test_size=test_size, shuffle=True, random_state=0)

    def _get_paired_train_subset(self, n_samples=1000):
        # type: (int) -> tuple
        """Return a random subset of the training spectra, drawn with identical row indices for X and Y.

        Using one common index set keeps each X spectrum paired with its Y counterpart, which is required for
        any X-versus-Y scatter plot.

        :param n_samples:   maximum number of spectra to draw (capped at the number of available training spectra)
        :return:            tuple (subset of train_X, subset of train_Y)
        """
        n_avail = min(self.train_X.shape[0], self.train_Y.shape[0])
        # sampling without replacement fails if more samples are requested than available -> cap the sample size
        idxs = np.random.choice(n_avail, min(n_samples, n_avail), replace=False)

        return self.train_X[idxs, :], self.train_Y[idxs, :]

    def plot_scatter_matrix(self, figsize=(15, 15), mode='intersensor'):
        # type: (tuple, str) -> None
        """Plot scatter matrices based on a random subset (max. 1000 spectra) of the training data.

        :param figsize: figure size of the band-to-band scatter matrices
                        (currently not used in 'intersensor' mode where the figure size is fixed to (25, 9))
        :param mode:    'intersensor': plot each band of X against each band of Y
                                       (rows: bands of Y, columns: bands of X);
                        any other value: plot one band-to-band scatter matrix per image
        """
        # TODO complete this function (axis labels, titles and a 1:1 line are still missing in 'intersensor' mode)
        from matplotlib import pyplot as plt

        train_X, train_Y = self._get_paired_train_subset(1000)

        if mode == 'intersensor':
            import seaborn

            n_bands_X, n_bands_Y = train_X.shape[1], train_Y.shape[1]

            # squeeze=False keeps 'axes' two-dimensional, also for single-band inputs
            fig, axes = plt.subplots(n_bands_Y, n_bands_X,
                                     figsize=(25, 9), sharex='all', sharey='all', squeeze=False)

            # one color per band of X (i.e., per subplot column)
            colors = seaborn.hls_palette(n_bands_X)

            for i in range(n_bands_Y):
                for j in range(n_bands_X):
                    # 'color' instead of 'c' avoids the ambiguity between a single RGB tuple and a value sequence
                    axes[i, j].scatter(train_X[:, j], train_Y[:, i], color=colors[j], label=str(j))

        else:
            df = DataFrame(train_X, columns=['Band %s' % b for b in range(1, self.im_X.bands + 1)])
            scatter_matrix(df, figsize=figsize, marker='.', hist_kwds={'bins': 50}, s=30, alpha=0.8)
            plt.suptitle('Image X band to band correlation')

            df = DataFrame(train_Y, columns=['Band %s' % b for b in range(1, self.im_Y.bands + 1)])
            scatter_matrix(df, figsize=figsize, marker='.', hist_kwds={'bins': 50}, s=30, alpha=0.8)
            plt.suptitle('Image Y band to band correlation')

    def plot_scattermatrix(self):
        # type: () -> None
        """Plot all training spectra of each band of X against each band of Y, including a 1:1 line.

        Rows hold the bands of Y in reversed order, columns hold the bands of X. The axes limits are fixed to
        [-0.1, 1.1].
        """
        # TODO complete this function (axis labels are still missing)
        import seaborn
        from matplotlib import pyplot as plt

        n_bands_X, n_bands_Y = self.train_X.shape[1], self.train_Y.shape[1]

        # squeeze=False keeps 'axes' two-dimensional, also for single-band inputs
        fig, axes = plt.subplots(n_bands_Y, n_bands_X,
                                 figsize=(25, 9), sharex='all', sharey='all', squeeze=False)

        # NOTE(review): the input images are not guaranteed to provide a 'satellite' attribute
        #               -> fall back to generic names instead of failing
        fig.suptitle('Correlation of %s and %s bands'
                     % (getattr(self.im_X, 'satellite', 'image X'), getattr(self.im_Y, 'satellite', 'image Y')),
                     size=25)

        # one color per band of X (i.e., per subplot column)
        colors = seaborn.hls_palette(n_bands_X)
        xticks = np.arange(0, 1.2, 0.2)

        for i in range(n_bands_Y):
            for j in range(n_bands_X):
                ax = axes[i, j]
                ax.scatter(self.train_X[:, j], self.train_Y[:, n_bands_Y - 1 - i], color=colors[j], label=str(j))
                ax.set_xlim(-0.1, 1.1)
                ax.set_ylim(-0.1, 1.1)
                ax.set_xticks(xticks)
                ax.plot([0, 1], [0, 1], c='red')  # 1:1 line

    def show_band_scatterplot(self, band_src_im, band_tgt_im):
        # type: (int, int) -> None
        """Show a density-colored scatterplot of one band of image X against one band of image Y.

        Only the first 10000 spectra are used.

        :param band_src_im: zero-based band index of image X
        :param band_tgt_im: zero-based band index of image Y
        """
        # TODO complete this function
        from scipy.stats import gaussian_kde
        from matplotlib import pyplot as plt

        # the spectral dimension is the last one -> select the band from the 2D spectra arrays (samples x bands)
        x = self.spectra_X[:, band_src_im][:10000]
        y = self.spectra_Y[:, band_tgt_im][:10000]

        # Calculate the point density
        xy = np.vstack([x, y])
        z = gaussian_kde(xy)(xy)

        plt.figure(figsize=(15, 15))
        plt.scatter(x, y, c=z, s=30, edgecolor='none')
        plt.show()

161 

162 

class RefCube(object):
    """Data model class for reference cubes holding the training data for later fitted machine learning classifiers."""

    def __init__(self, filepath='', satellite='', sensor='', LayerBandsAssignment=None):
        # type: (str, str, str, list) -> None
        """Get instance of RefCube.

        :param filepath:                file path for importing an existing reference cube from disk
        :param satellite:               the satellite for which the reference cube holds its spectral data
        :param sensor:                  the sensor for which the reference cube holds its spectral data
        :param LayerBandsAssignment:    the LayerBandsAssignment for which the reference cube holds its spectral data
        """
        # privates
        self._col_imName_dict = dict()
        self._wavelenths = []

        # defaults
        self.data = GeoArray(np.empty((0, 0, len(LayerBandsAssignment) if LayerBandsAssignment else 0)),
                             nodata=-9999)
        self.srcImNames = []

        # args/ kwargs
        self.filepath = filepath
        self.satellite = satellite
        self.sensor = sensor
        self.LayerBandsAssignment = LayerBandsAssignment or []

        if filepath:
            # NOTE: this overrides the attributes set above with the values stored in the metadata file on disk
            self.read_data_from_disk(filepath)

        if self.satellite and self.sensor and self.LayerBandsAssignment:
            self._add_bandnames_wavelenghts_to_meta()

    def _add_bandnames_wavelenghts_to_meta(self):
        """Write band names and wavelengths to the metadata of the data array."""
        # set bandnames
        self.data.bandnames = ['Band %s' % b for b in self.LayerBandsAssignment]

        # set wavelengths
        self.data.metadata.band_meta['wavelength'] = self.wavelengths

    @property
    def n_images(self):
        """Return the number of training images from which the reference cube contains spectral samples."""
        return self.data.shape[1]

    @property
    def n_signatures(self):
        """Return the number of spectral signatures per training image included in the reference cube."""
        return self.data.shape[0]

    @property
    def n_clusters(self):
        """Return the number of spectral clusters used for clustering source images for the reference cube.

        The number is parsed from the file name, i.e., from the third '__'-separated field between 'refcube__'
        and '.bsq', which is expected to look like 'nclust<number>'.

        :return: number of clusters or None if no file path is set or the file name does not follow this scheme
        """
        if not self.filepath:
            return None

        match = re.search(r'refcube__(.*)\.bsq', os.path.basename(self.filepath))
        if match is None:
            return None

        try:
            return int(match.group(1).split('__')[2].split('nclust')[1])
        except (IndexError, ValueError):
            # file name contains no (valid) 'nclust<number>' field at the expected position
            return None

    @property
    def n_signatures_per_cluster(self):
        """Return the number of signatures per cluster (None if the number of clusters is unknown)."""
        if self.n_clusters:
            return self.n_signatures // self.n_clusters

    @property
    def col_imName_dict(self):
        # type: () -> OrderedDict
        """Return an ordered dict containing the file base names of the original training images for each column."""
        return OrderedDict((col, imName) for col, imName in zip(range(self.n_images), self.srcImNames))

    @col_imName_dict.setter
    def col_imName_dict(self, col_imName_dict):
        # type: (dict) -> None
        self._col_imName_dict = col_imName_dict
        self.srcImNames = list(col_imName_dict.values())

    @property
    def wavelengths(self):
        """Return the wavelengths of the bands (taken from the RSR if not explicitly set before)."""
        if not self._wavelenths and self.satellite and self.sensor and self.LayerBandsAssignment:
            self._wavelenths = list(RSR(self.satellite, self.sensor,
                                        LayerBandsAssignment=self.LayerBandsAssignment).wvl)

        return self._wavelenths

    @wavelengths.setter
    def wavelengths(self, wavelengths):
        self._wavelenths = wavelengths

    def add_refcube_array(self, refcube_array, src_imnames, LayerBandsAssignment):
        # type: (Union[str, np.ndarray], list, list) -> None
        """Add the given array to the RefCube instance.

        :param refcube_array:           3D array or file path of the reference cube to be added
                                        (spectral samples /signatures x training images x spectral bands)
        :param src_imnames:             list of training image file base names from which the given cube received data
        :param LayerBandsAssignment:    LayerBandsAssignment of the spectral bands of the given 3D array
        :return:
        """
        # validation
        assert LayerBandsAssignment == self.LayerBandsAssignment, \
            "%s != %s" % (LayerBandsAssignment, self.LayerBandsAssignment)

        if self.data.size:
            # append the new image columns to the existing reference cube
            new_cube = np.hstack([self.data, refcube_array])
            self.data = GeoArray(new_cube, nodata=self.data.nodata)
        else:
            self.data = GeoArray(refcube_array, nodata=self.data.nodata)

        self.srcImNames.extend(src_imnames)

    def add_spectra(self, spectra, src_imname, LayerBandsAssignment):
        # type: (np.ndarray, str, list) -> None
        """Add a set of spectral signatures to the reference cube.

        :param spectra:              2D numpy array with rows: spectral samples / columns: spectral information (bands)
        :param src_imname:           image basename of the source hyperspectral image
        :param LayerBandsAssignment: LayerBandsAssignment for the spectral dimension of the passed spectra,
                                     e.g., ['1', '2', '3', '4', '5', '6L', '6H', '7', '8']
        :raises ValueError:          if the number of spectra does not match the number of signatures already
                                     contained in the reference cube
        """
        # validation
        assert LayerBandsAssignment == self.LayerBandsAssignment, \
            "%s != %s" % (LayerBandsAssignment, self.LayerBandsAssignment)

        # reshape 2D spectra array to one image column (refcube is an image with spectral information in the 3rd dim.)
        im_col = spectra.reshape((spectra.shape[0], 1, spectra.shape[1]))

        meta = self.data.metadata  # needs to be copied to the new GeoArray

        if self.data.size:
            # validation
            if spectra.shape[0] != self.data.shape[0]:
                raise ValueError('The number of signatures in the given spectra array does not match the dimensions of '
                                 'the reference cube.')

            # append spectra to existing reference cube
            new_cube = np.hstack([self.data, im_col])
            self.data = GeoArray(new_cube, nodata=self.data.nodata)

        else:
            self.data = GeoArray(im_col, nodata=self.data.nodata)

        # copy previous metadata to the new GeoArray instance
        self.data.metadata = meta

        # add source image name to list of image names
        self.srcImNames.append(src_imname)

    @property
    def metadata(self):
        """Return an ordered dictionary holding the metadata of the reference cube."""
        attrs2include = ['satellite', 'sensor', 'filepath', 'n_signatures', 'n_images', 'n_clusters',
                         'n_signatures_per_cluster', 'col_imName_dict', 'LayerBandsAssignment', 'wavelengths']
        return OrderedDict((k, getattr(self, k)) for k in attrs2include)

    def get_band_combination(self, tgt_LBA):
        # type: (List[str]) -> GeoArray
        """Get an array according to the bands order given by a target LayerBandsAssignment.

        :param tgt_LBA: target LayerBandsAssignment
        :return:        the data array itself if the band order already matches, otherwise a new GeoArray
                        containing the requested bands
        """
        if tgt_LBA != self.LayerBandsAssignment:
            # map each band name of the current LayerBandsAssignment to its index in the 3rd array dimension
            cur_LBA_dict = dict(zip(self.LayerBandsAssignment, range(len(self.LayerBandsAssignment))))
            tgt_bIdxList = [cur_LBA_dict[lr] for lr in tgt_LBA]

            return GeoArray(np.take(self.data, tgt_bIdxList, axis=2), nodata=self.data.nodata)
        else:
            return self.data

    def get_spectra_dataframe(self, tgt_LBA):
        # type: (List[str]) -> DataFrame
        """Return a pandas.DataFrame [sample x band] according to the given LayerBandsAssignment.

        :param tgt_LBA: target LayerBandsAssignment
        :return:
        """
        imdata = self.get_band_combination(tgt_LBA)
        spectra = im2spectra(imdata)
        df = DataFrame(spectra, columns=['B%s' % band for band in tgt_LBA])

        return df

    def rearrange_layers(self, tgt_LBA):
        # type: (List[str]) -> None
        """Rearrange the spectral bands of the reference cube according to the given LayerBandsAssignment.

        :param tgt_LBA: target LayerBandsAssignment
        """
        self.data = self.get_band_combination(tgt_LBA)
        self.LayerBandsAssignment = tgt_LBA

    def read_data_from_disk(self, filepath):
        # type: (str) -> None
        """Read the reference cube and its metadata ('.meta' JSON file next to the cube) from disk.

        :param filepath: file path of the reference cube
        """
        self.data = GeoArray(filepath)

        with open(os.path.splitext(filepath)[0] + '.meta', 'r') as metaF:
            meta = json.load(metaF)

        # these attributes are read-only properties derived from the data array or the file name
        pure_getters = ['n_signatures', 'n_images', 'n_clusters', 'n_signatures_per_cluster']

        for k, v in meta.items():
            if k not in pure_getters:
                setattr(self, k, v)

    def save(self, path_out, fmt='ENVI'):
        # type: (str, str) -> None
        """Save the reference cube to disk.

        :param path_out:    output path on disk
        :param fmt:         output format as GDAL format code
        :return:
        """
        self.filepath = self.filepath or path_out
        self.data.save(out_path=path_out, fmt=fmt)

        # save metadata as JSON file
        with open(os.path.splitext(path_out)[0] + '.meta', 'w') as metaF:
            json.dump(self.metadata.copy(), metaF, separators=(',', ': '), indent=4)

    def _get_spectra_by_label_imname(self, cluster_label, image_basename, n_spectra2get=100, random_state=0):
        """Return randomly selected spectra of the given cluster label and training image.

        The spectra are drawn with replacement, i.e., the result may contain duplicates.

        :param cluster_label:   label of the cluster to get spectra from
        :param image_basename:  base name of the training image (must be contained in self.srcImNames)
        :param n_spectra2get:   number of spectra to return
        :param random_state:    seed of the random number generator used to select the spectra
        :return:                2D array (spectra x bands)
        """
        n_per_cluster = self.n_signatures_per_cluster

        # the signatures of each cluster are stored as one contiguous block of rows
        cluster_start_pos_all = list(range(0, self.n_signatures, n_per_cluster))
        cluster_start_pos = cluster_start_pos_all[cluster_label]
        spectra = self.data[cluster_start_pos: cluster_start_pos + n_per_cluster,
                            self.srcImNames.index(image_basename)]
        idxs_specIncl = np.random.RandomState(seed=random_state).choice(range(n_per_cluster), n_spectra2get)

        return spectra[idxs_specIncl, :]

    def plot_sample_spectra(self, image_basename, cluster_label='all', include_mean_spectrum=True,
                            include_median_spectrum=True, ncols=5, **kw_fig):
        # type: (str, Union[str, int, List], bool, bool, int, dict) -> 'plt.figure'
        """Plot 100 randomly selected sample spectra per cluster for the given training image.

        :param image_basename:          base name of the training image (must be contained in self.srcImNames)
        :param cluster_label:           a single cluster label (integer), a list of cluster labels or 'all'
        :param include_mean_spectrum:   whether to add the mean of the plotted spectra (solid black line)
        :param include_median_spectrum: whether to add the median of the plotted spectra (dashed black line)
        :param ncols:                   maximum number of subplot columns (only used if multiple clusters are plotted)
        :param kw_fig:                  keyword arguments to be passed to plt.subplots()
                                        (only used if multiple clusters are plotted)
        :raises ValueError:             if cluster_label is neither an integer, nor a list, nor 'all'
        :return:                        the created figure
        """
        from matplotlib import pyplot as plt

        if isinstance(cluster_label, int):
            lbls2plot = [cluster_label]
        elif isinstance(cluster_label, list):
            lbls2plot = cluster_label
        elif cluster_label == 'all':
            lbls2plot = list(range(self.n_clusters))
        else:
            raise ValueError(cluster_label)

        # create a single plot
        if len(lbls2plot) == 1:
            # use the resolved label (cluster_label itself may be a one-element list or 'all')
            lbl = lbls2plot[0]

            fig, axes = plt.figure(), None
            spectra = self._get_spectra_by_label_imname(lbl, image_basename, 100)
            for spectrum in spectra:
                plt.plot(self.wavelengths, spectrum)

            plt.xlabel('wavelength [nm]')
            plt.ylabel('%s %s\nreflectance [0-10000]' % (self.satellite, self.sensor))
            plt.title('Cluster #%s' % lbl)

            if include_mean_spectrum:
                plt.plot(self.wavelengths, np.mean(spectra, axis=0), c='black', lw=3)
            if include_median_spectrum:
                plt.plot(self.wavelengths, np.median(spectra, axis=0), '--', c='black', lw=3)

        # create a plot with multiple subplots
        else:
            nplots = len(lbls2plot)
            ncols = nplots if nplots < ncols else ncols
            nrows = nplots // ncols if not nplots % ncols else nplots // ncols + 1
            figsize = (4 * ncols, 3 * nrows)
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex='all', sharey='all',
                                     **kw_fig)

            for lbl, ax in tqdm(zip(lbls2plot, axes.flatten()), total=nplots):
                spectra = self._get_spectra_by_label_imname(lbl, image_basename, 100)

                for spectrum in spectra:
                    ax.plot(self.wavelengths, spectrum, lw=1)

                if include_mean_spectrum:
                    ax.plot(self.wavelengths, np.mean(spectra, axis=0), c='black', lw=2)
                if include_median_spectrum:
                    ax.plot(self.wavelengths, np.median(spectra, axis=0), '--', c='black', lw=3)

                ax.grid(lw=0.2)
                ax.set_ylim(0, 10000)

                # label only the outer subplots
                if ax.get_subplotspec().is_last_row():
                    ax.set_xlabel('wavelength [nm]')
                if ax.get_subplotspec().is_first_col():
                    ax.set_ylabel('%s %s\nreflectance [0-10000]' % (self.satellite, self.sensor))
                ax.set_title('Cluster #%s' % lbl)

        fig.suptitle("Refcube spectra from image '%s':" % image_basename, fontsize=15)
        plt.tight_layout(rect=(0, 0, 1, .95))
        plt.show()

        return fig