Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
3# spechomo, Spectral homogenization of multispectral satellite data
4#
5# Copyright (C) 2019-2021
6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam,
8# Germany (https://www.gfz-potsdam.de/)
9#
10# This software was developed within the context of the GeoMultiSens project funded
11# by the German Federal Ministry of Education and Research
12# (project grant code: 01 IS 14 010 A-C).
13#
14# Licensed under the Apache License, Version 2.0 (the "License");
15# you may not use this file except in compliance with the License.
16# You may obtain a copy of the License at
17#
18# http://www.apache.org/licenses/LICENSE-2.0
19#
20# Please note the following exception: `spechomo` depends on tqdm, which is
21# distributed under the Mozilla Public Licence (MPL) v2.0 except for the files
22# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore".
23# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE.
24#
25# Unless required by applicable law or agreed to in writing, software
26# distributed under the License is distributed on an "AS IS" BASIS,
27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28# See the License for the specific language governing permissions and
29# limitations under the License.
31import json
32import os
33import re
34from collections import OrderedDict
35from typing import Union, List # noqa F401 # flake8 issue
36from tqdm import tqdm
38import numpy as np
39from geoarray import GeoArray
40from pandas import DataFrame
41from pandas.plotting import scatter_matrix
42from pyrsr import RSR
44from .utils import im2spectra
class TrainingData(object):
    """Class for analyzing statistical relations between a pair of machine learning training data cubes."""

    def __init__(self, im_X, im_Y, test_size):
        # type: (Union[GeoArray, np.ndarray], Union[GeoArray, np.ndarray], Union[float, int]) -> None
        """Get instance of TrainingData.

        The spectra of both images are split together, i.e., row k of train_X/test_X belongs to row k of
        train_Y/test_Y.

        :param im_X:        input image X
        :param im_Y:        input image Y
        :param test_size:   test size (proportion as float between 0 and 1) or absolute value as integer
        """
        from sklearn.model_selection import train_test_split  # avoids static TLS error here

        self.im_X = GeoArray(im_X)
        self.im_Y = GeoArray(im_Y)

        # Set spectra (3D to 2D conversion)
        self.spectra_X = im2spectra(self.im_X)
        self.spectra_Y = im2spectra(self.im_Y)

        # Set train and test variables
        # NOTE: If random_state is set to an Integer, train_test_split will always select the same 'pseudo-random' set
        #       of the input data.
        self.train_X, self.test_X, self.train_Y, self.test_Y = \
            train_test_split(self.spectra_X, self.spectra_Y, test_size=test_size, shuffle=True, random_state=0)

    @staticmethod
    def _get_random_subset_idxs(n_total, n_max, random_state=None):
        # type: (int, int, Union[int, None]) -> np.ndarray
        """Return up to n_max unique, randomly selected row indices out of n_total rows.

        :param n_total:         total number of available rows
        :param n_max:           maximum number of indices to return
        :param random_state:    seed for the random selection (None: use numpy's global random state)
        :return:                1D integer array of unique indices with length min(n_total, n_max)
        """
        n_samples = min(int(n_total), int(n_max))
        rng = np.random if random_state is None else np.random.RandomState(seed=random_state)

        # replace=False -> no duplicates; n_samples never exceeds n_total, so this cannot fail for small inputs
        return rng.choice(n_total, n_samples, replace=False)

    def plot_scatter_matrix(self, figsize=(15, 15), mode='intersensor', n_samples=1000, random_state=None):
        """Plot band-to-band scatter plots of a random subset of the training spectra.

        TODO: complete this function (axis labels with band wavelengths are still missing).

        :param figsize:         figure size
        :param mode:            'intersensor': one subplot per pair of (image Y band [rows], image X band [columns]);
                                any other value: separate pandas scatter matrices for image X and image Y
        :param n_samples:       maximum number of randomly selected training spectra to plot
        :param random_state:    seed for the random subset selection (None: use numpy's global random state)
        """
        from matplotlib import pyplot as plt

        # Use the SAME row indices for X and Y - row k of train_X corresponds to row k of train_Y.
        idxs = self._get_random_subset_idxs(self.train_X.shape[0], n_samples, random_state)
        train_X = self.train_X[idxs, :]
        train_Y = self.train_Y[idxs, :]

        if mode == 'intersensor':
            import seaborn

            n_bands_X, n_bands_Y = train_X.shape[1], train_Y.shape[1]

            # rows: bands of image Y, columns: bands of image X
            # (squeeze=False guarantees a 2D axes array, even for single-band inputs)
            fig, axes = plt.subplots(n_bands_Y, n_bands_X, figsize=figsize, sharex='all', sharey='all',
                                     squeeze=False)
            colors = seaborn.hls_palette(n_bands_X)

            for i in range(n_bands_Y):
                for j in range(n_bands_X):
                    axes[i, j].scatter(train_X[:, j], train_Y[:, i], color=colors[j], label=str(j))

            for j in range(n_bands_X):
                axes[-1, j].set_xlabel('X band %s' % (j + 1))
            for i in range(n_bands_Y):
                axes[i, 0].set_ylabel('Y band %s' % (i + 1))

        else:
            df = DataFrame(train_X, columns=['Band %s' % b for b in range(1, self.im_X.bands + 1)])
            scatter_matrix(df, figsize=figsize, marker='.', hist_kwds={'bins': 50}, s=30, alpha=0.8)
            plt.suptitle('Image X band to band correlation')

            df = DataFrame(train_Y, columns=['Band %s' % b for b in range(1, self.im_Y.bands + 1)])
            scatter_matrix(df, figsize=figsize, marker='.', hist_kwds={'bins': 50}, s=30, alpha=0.8)
            plt.suptitle('Image Y band to band correlation')

    def plot_scattermatrix(self):
        """Plot each band of image X against each band of image Y, using all training spectra.

        Columns correspond to the bands of image X, rows to the bands of image Y in reverse order (the top row
        shows the last band of image Y). A red 1:1 line is added to each subplot.

        TODO: complete this function (axis labels with band wavelengths are still missing).
        NOTE(review): the fixed axis limits [-0.1, 1.1] presumably assume reflectances scaled to 0-1 - confirm.
        """
        import seaborn
        from matplotlib import pyplot as plt

        n_bands_X, n_bands_Y = self.train_X.shape[1], self.train_Y.shape[1]

        fig, axes = plt.subplots(n_bands_Y, n_bands_X, figsize=(25, 9), sharex='all', sharey='all', squeeze=False)
        fig.suptitle('Correlation of %s and %s bands'
                     % (getattr(self.im_X, 'satellite', 'image X'), getattr(self.im_Y, 'satellite', 'image Y')),
                     size=25)
        colors = seaborn.hls_palette(n_bands_X)

        for i in range(n_bands_Y):
            for j in range(n_bands_X):
                ax = axes[i, j]
                ax.scatter(self.train_X[:, j], self.train_Y[:, n_bands_Y - 1 - i], color=colors[j], label=str(j))
                ax.set_xlim(-0.1, 1.1)
                ax.set_ylim(-0.1, 1.1)
                ax.plot([0, 1], [0, 1], c='red')

        # the x-axes are shared, so setting the ticks on the bottom row applies to all subplots
        for ax in axes[-1, :]:
            ax.set_xticks(np.arange(0, 1.2, 0.2))

    def show_band_scatterplot(self, band_src_im, band_tgt_im, n_pixels=10000):
        """Show a point-density colored scatter plot of one band of image X against one band of image Y.

        TODO: complete this function (axis labels are still missing).

        :param band_src_im: band index (0-based) within image X
        :param band_tgt_im: band index (0-based) within image Y
        :param n_pixels:    number of pixels (from the beginning of the flattened band) to include
        """
        from scipy.stats import gaussian_kde
        from matplotlib import pyplot as plt

        # The image arrays are (rows, columns, bands), so the band has to be selected along the LAST axis.
        # np.atleast_3d also handles single-band (2D) images.
        x = np.atleast_3d(self.im_X[:])[:, :, band_src_im].flatten()[:n_pixels]
        y = np.atleast_3d(self.im_Y[:])[:, :, band_tgt_im].flatten()[:n_pixels]

        # Calculate the point density
        xy = np.vstack([x, y])
        z = gaussian_kde(xy)(xy)

        plt.figure(figsize=(15, 15))
        plt.scatter(x, y, c=z, s=30, edgecolor='none')
        plt.show()
class RefCube(object):
    """Data model class for reference cubes holding the training data for later fitted machine learning classifiers.

    The spectral data are stored in self.data as a 3D GeoArray with the dimensions
    (spectral samples / signatures x training images x spectral bands).
    """

    def __init__(self, filepath='', satellite='', sensor='', LayerBandsAssignment=None):
        # type: (str, str, str, list) -> None
        """Get instance of RefCube.

        :param filepath:                file path for importing an existing reference cube from disk
        :param satellite:               the satellite for which the reference cube holds its spectral data
        :param sensor:                  the sensor for which the reference cube holds its spectral data
        :param LayerBandsAssignment:    the LayerBandsAssignment for which the reference cube holds its spectral data
        """
        # privates
        self._col_imName_dict = dict()
        self._wavelenths = []  # lazily filled cache, see the `wavelengths` property

        # defaults: an empty cube (0 signatures, 0 images) that already has the correct number of bands
        self.data = GeoArray(np.empty((0, 0, len(LayerBandsAssignment) if LayerBandsAssignment else 0)),
                             nodata=-9999)
        self.srcImNames = []

        # args/ kwargs
        self.filepath = filepath
        self.satellite = satellite
        self.sensor = sensor
        self.LayerBandsAssignment = LayerBandsAssignment or []

        if filepath:
            self.read_data_from_disk(filepath)

        if self.satellite and self.sensor and self.LayerBandsAssignment:
            self._add_bandnames_wavelenghts_to_meta()

    def _add_bandnames_wavelenghts_to_meta(self):
        """Write band names (derived from the LayerBandsAssignment) and wavelengths into the GeoArray metadata."""
        # set bandnames
        self.data.bandnames = ['Band %s' % b for b in self.LayerBandsAssignment]

        # set wavelengths
        self.data.metadata.band_meta['wavelength'] = self.wavelengths

    @property
    def n_images(self):
        """Return the number of training images from which the reference cube contains spectral samples."""
        return self.data.shape[1]

    @property
    def n_signatures(self):
        """Return the number of spectral signatures per training image included in the reference cube."""
        return self.data.shape[0]

    @property
    def n_clusters(self):
        """Return the number of spectral clusters used for clustering source images for the reference cube.

        The value is parsed from the file name, which is expected to look like
        'refcube__<part1>__<part2>__nclust<N>[...].bsq'. None is returned if there is no file path or if the file
        name does not follow this pattern (e.g., after saving the cube under a custom name).
        """
        if not self.filepath:
            return None

        match = re.search(r'refcube__(.*)\.bsq', os.path.basename(self.filepath))
        if not match:
            return None

        parts = match.group(1).split('__')
        if len(parts) < 3 or 'nclust' not in parts[2]:
            return None

        try:
            return int(parts[2].split('nclust')[1])
        except ValueError:
            return None

    @property
    def n_signatures_per_cluster(self):
        """Return the number of spectral signatures per cluster (None if the number of clusters is unknown)."""
        if self.n_clusters:
            return self.n_signatures // self.n_clusters
        return None

    @property
    def col_imName_dict(self):
        # type: () -> OrderedDict
        """Return an ordered dict containing the file base names of the original training images for each column.

        NOTE: The dict is always rebuilt from self.srcImNames (limited to the number of image columns).
        """
        return OrderedDict((col, imName) for col, imName in zip(range(self.n_images), self.srcImNames))

    @col_imName_dict.setter
    def col_imName_dict(self, col_imName_dict):
        # type: (dict) -> None
        self._col_imName_dict = col_imName_dict
        self.srcImNames = list(col_imName_dict.values())

    @property
    def wavelengths(self):
        """Return the wavelengths of the bands of the reference cube.

        If not set explicitly, they are lazily derived from pyrsr.RSR for the given satellite, sensor and
        LayerBandsAssignment and cached afterwards.
        """
        if not self._wavelenths and self.satellite and self.sensor and self.LayerBandsAssignment:
            self._wavelenths = list(RSR(self.satellite, self.sensor,
                                        LayerBandsAssignment=self.LayerBandsAssignment).wvl)

        return self._wavelenths

    @wavelengths.setter
    def wavelengths(self, wavelengths):
        self._wavelenths = wavelengths

    def add_refcube_array(self, refcube_array, src_imnames, LayerBandsAssignment):
        # type: (Union[str, np.ndarray], list, list) -> None
        """Add the given array to the RefCube instance.

        :param refcube_array:           3D array or file path of the reference cube to be added
                                        (spectral samples /signatures x training images x spectral bands)
        :param src_imnames:             list of training image file base names from which the given cube received data
                                        (one name per image column of refcube_array)
        :param LayerBandsAssignment:    LayerBandsAssignment of the spectral bands of the given 3D array
        :raises AssertionError:         if LayerBandsAssignment differs from the one of the reference cube
        :raises ValueError:             if the number of signatures or image names does not fit the given array
        """
        # validation
        assert LayerBandsAssignment == self.LayerBandsAssignment, \
            "%s != %s" % (LayerBandsAssignment, self.LayerBandsAssignment)

        # a file path has to be read into memory before it can be stacked onto the existing cube
        if isinstance(refcube_array, str):
            refcube_array = GeoArray(refcube_array)[:]

        if len(src_imnames) != refcube_array.shape[1]:
            raise ValueError('The number of given source image names (%d) does not match the number of image columns '
                             'of the given reference cube array (%d).' % (len(src_imnames), refcube_array.shape[1]))

        if self.data.size:
            if refcube_array.shape[0] != self.data.shape[0]:
                raise ValueError('The number of signatures in the given reference cube array does not match the '
                                 'dimensions of the reference cube.')

            new_cube = np.hstack([self.data, refcube_array])
            self.data = GeoArray(new_cube, nodata=self.data.nodata)
        else:
            self.data = GeoArray(refcube_array, nodata=self.data.nodata)

        self.srcImNames.extend(src_imnames)

    def add_spectra(self, spectra, src_imname, LayerBandsAssignment):
        # type: (np.ndarray, str, list) -> None
        """Add a set of spectral signatures to the reference cube.

        :param spectra:                 2D numpy array with rows: spectral samples / columns: spectral information (bands)
        :param src_imname:              image basename of the source hyperspectral image
        :param LayerBandsAssignment:    LayerBandsAssignment for the spectral dimension of the passed spectra,
                                        e.g., ['1', '2', '3', '4', '5', '6L', '6H', '7', '8']
        :raises AssertionError:         if LayerBandsAssignment differs from the one of the reference cube
        :raises ValueError:             if the number of signatures does not match the existing reference cube
        """
        # validation
        assert LayerBandsAssignment == self.LayerBandsAssignment, \
            "%s != %s" % (LayerBandsAssignment, self.LayerBandsAssignment)

        # reshape 2D spectra array to one image column (refcube is an image with spectral information in the 3rd dim.)
        im_col = spectra.reshape((spectra.shape[0], 1, spectra.shape[1]))

        meta = self.data.metadata  # needs to be copied to the new GeoArray

        if self.data.size:
            # validation
            if spectra.shape[0] != self.data.shape[0]:
                raise ValueError('The number of signatures in the given spectra array does not match the dimensions of '
                                 'the reference cube.')

            # append spectra to existing reference cube
            new_cube = np.hstack([self.data, im_col])
            self.data = GeoArray(new_cube, nodata=self.data.nodata)

        else:
            self.data = GeoArray(im_col, nodata=self.data.nodata)

        # copy previous metadata to the new GeoArray instance
        self.data.metadata = meta

        # add source image name to list of image names
        self.srcImNames.append(src_imname)

    @property
    def metadata(self):
        """Return an ordered dictionary holding the metadata of the reference cube."""
        attrs2include = ['satellite', 'sensor', 'filepath', 'n_signatures', 'n_images', 'n_clusters',
                         'n_signatures_per_cluster', 'col_imName_dict', 'LayerBandsAssignment', 'wavelengths']
        return OrderedDict((k, getattr(self, k)) for k in attrs2include)

    def get_band_combination(self, tgt_LBA):
        # type: (List[str]) -> GeoArray
        """Get an array according to the bands order given by a target LayerBandsAssignment.

        :param tgt_LBA: target LayerBandsAssignment (may also be a subset of the current bands)
        :return:        a new GeoArray with the selected bands, or self.data itself if tgt_LBA equals the current
                        LayerBandsAssignment
        :raises KeyError: if tgt_LBA contains a band that is not part of the reference cube
        """
        if tgt_LBA != self.LayerBandsAssignment:
            cur_LBA_dict = dict(zip(self.LayerBandsAssignment, range(len(self.LayerBandsAssignment))))
            tgt_bIdxList = [cur_LBA_dict[lr] for lr in tgt_LBA]

            return GeoArray(np.take(self.data, tgt_bIdxList, axis=2), nodata=self.data.nodata)
        else:
            return self.data

    def get_spectra_dataframe(self, tgt_LBA):
        # type: (List[str]) -> DataFrame
        """Return a pandas.DataFrame [sample x band] according to the given LayerBandsAssignment.

        :param tgt_LBA: target LayerBandsAssignment
        :return:        DataFrame with one column per band, named 'B<band>'
        """
        imdata = self.get_band_combination(tgt_LBA)
        spectra = im2spectra(imdata)
        df = DataFrame(spectra, columns=['B%s' % band for band in tgt_LBA])

        return df

    def rearrange_layers(self, tgt_LBA):
        # type: (List[str]) -> None
        """Rearrange the spectral bands of the reference cube according to the given LayerBandsAssignment.

        Cached wavelengths (and, if satellite and sensor are known, the band metadata) are re-ordered accordingly.

        :param tgt_LBA: target LayerBandsAssignment
        """
        # remember the wavelength of each band BEFORE the band order changes
        if self._wavelenths and len(self._wavelenths) == len(self.LayerBandsAssignment):
            wvl_by_band = dict(zip(self.LayerBandsAssignment, self._wavelenths))
        else:
            wvl_by_band = {}

        self.data = self.get_band_combination(tgt_LBA)
        self.LayerBandsAssignment = tgt_LBA

        # keep the wavelength cache aligned with the new band order (an empty cache is lazily recomputed if possible)
        self._wavelenths = [wvl_by_band[b] for b in tgt_LBA] if wvl_by_band else []

        if self.satellite and self.sensor and self.LayerBandsAssignment:
            self._add_bandnames_wavelenghts_to_meta()

    def read_data_from_disk(self, filepath):
        """Read the reference cube and its accompanying '.meta' JSON file from disk.

        :param filepath: path of the reference cube image; the metadata is read from the same path with the
                         extension replaced by '.meta'
        """
        self.data = GeoArray(filepath)

        with open(os.path.splitext(filepath)[0] + '.meta', 'r', encoding='utf-8') as metaF:
            meta = json.load(metaF)
            for k, v in meta.items():
                if k in ['n_signatures', 'n_images', 'n_clusters', 'n_signatures_per_cluster']:
                    continue  # skip pure getters
                else:
                    setattr(self, k, v)

    def save(self, path_out, fmt='ENVI'):
        # type: (str, str) -> None
        """Save the reference cube to disk.

        The metadata are written as JSON to the same path with the extension replaced by '.meta'.

        :param path_out:    output path on disk
        :param fmt:         output format as GDAL format code
        """
        self.filepath = self.filepath or path_out
        self.data.save(out_path=path_out, fmt=fmt)

        # save metadata as JSON file
        with open(os.path.splitext(path_out)[0] + '.meta', 'w', encoding='utf-8') as metaF:
            json.dump(self.metadata.copy(), metaF, separators=(',', ': '), indent=4)

    def _get_spectra_by_label_imname(self, cluster_label, image_basename, n_spectra2get=100, random_state=0):
        """Return randomly selected spectra of one cluster within the column of the given training image.

        Signatures of the same cluster are expected to be stored contiguously along the first axis, in blocks of
        n_signatures_per_cluster. The spectra are drawn WITH replacement (numpy's default for choice()).

        :param cluster_label:   cluster label (0-based)
        :param image_basename:  training image base name (must be contained in self.srcImNames)
        :param n_spectra2get:   number of spectra to return
        :param random_state:    seed for the random selection
        :return:                2D array (n_spectra2get x bands)
        """
        cluster_start_pos_all = list(range(0, self.n_signatures, self.n_signatures_per_cluster))
        cluster_start_pos = cluster_start_pos_all[cluster_label]
        spectra = self.data[cluster_start_pos: cluster_start_pos + self.n_signatures_per_cluster,
                            self.srcImNames.index(image_basename)]
        idxs_specIncl = np.random.RandomState(seed=random_state).choice(range(self.n_signatures_per_cluster),
                                                                        n_spectra2get)
        return spectra[idxs_specIncl, :]

    def plot_sample_spectra(self, image_basename, cluster_label='all', include_mean_spectrum=True,
                            include_median_spectrum=True, ncols=5, n_spectra2plot=100, **kw_fig):
        # type: (str, Union[str, int, List[int]], bool, bool, int, int, **object) -> 'plt.Figure'
        """Plot randomly selected sample spectra of one or multiple clusters of the given training image.

        :param image_basename:          training image base name (must be contained in self.srcImNames)
        :param cluster_label:           a single cluster label, a list of cluster labels or 'all'
        :param include_mean_spectrum:   whether to add the mean spectrum (black line)
        :param include_median_spectrum: whether to add the median spectrum (black dashed line)
        :param ncols:                   maximum number of subplot columns in case multiple clusters are plotted
        :param n_spectra2plot:          number of randomly selected spectra to plot per cluster
        :param kw_fig:                  keyword arguments passed to matplotlib's figure creation
        :return:                        the matplotlib figure
        """
        from matplotlib import pyplot as plt

        if isinstance(cluster_label, (int, np.integer)):
            lbls2plot = [int(cluster_label)]
        elif isinstance(cluster_label, (list, tuple)):
            lbls2plot = list(cluster_label)
        elif cluster_label == 'all':
            lbls2plot = list(range(self.n_clusters))
        else:
            raise ValueError(cluster_label)

        if not lbls2plot:
            raise ValueError('No cluster labels to plot.')

        wvl = self.wavelengths
        ylabel = '%s %s\nreflectance [0-10000]' % (self.satellite, self.sensor)

        # create a single plot
        if len(lbls2plot) == 1:
            lbl = lbls2plot[0]

            fig = plt.figure(**kw_fig)
            spectra = self._get_spectra_by_label_imname(lbl, image_basename, n_spectra2plot)
            for spectrum in spectra:
                plt.plot(wvl, spectrum)

            plt.xlabel('wavelength [nm]')
            plt.ylabel(ylabel)
            plt.title('Cluster #%s' % lbl)

            if include_mean_spectrum:
                plt.plot(wvl, np.mean(spectra, axis=0), c='black', lw=3)
            if include_median_spectrum:
                plt.plot(wvl, np.median(spectra, axis=0), '--', c='black', lw=3)

        # create a plot with multiple subplots
        else:
            nplots = len(lbls2plot)
            ncols = min(nplots, ncols)
            nrows = -(-nplots // ncols)  # ceiling division
            kw_fig.setdefault('figsize', (4 * ncols, 3 * nrows))  # a user-given figsize takes precedence
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex='all', sharey='all', squeeze=False,
                                     **kw_fig)

            for lbl, ax in tqdm(zip(lbls2plot, axes.flatten()), total=nplots):
                spectra = self._get_spectra_by_label_imname(lbl, image_basename, n_spectra2plot)

                for spectrum in spectra:
                    ax.plot(wvl, spectrum, lw=1)

                if include_mean_spectrum:
                    ax.plot(wvl, np.mean(spectra, axis=0), c='black', lw=2)
                if include_median_spectrum:
                    ax.plot(wvl, np.median(spectra, axis=0), '--', c='black', lw=3)

                ax.grid(lw=0.2)
                ax.set_ylim(0, 10000)

                if ax.get_subplotspec().is_last_row():
                    ax.set_xlabel('wavelength [nm]')
                if ax.get_subplotspec().is_first_col():
                    ax.set_ylabel(ylabel)
                ax.set_title('Cluster #%s' % lbl)

        fig.suptitle("Refcube spectra from image '%s':" % image_basename, fontsize=15)
        plt.tight_layout(rect=(0, 0, 1, .95))
        plt.show()

        return fig