Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2 

3# spechomo, Spectral homogenization of multispectral satellite data 

4# 

5# Copyright (C) 2019-2021 

6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de) 

7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam, 

8# Germany (https://www.gfz-potsdam.de/) 

9# 

10# This software was developed within the context of the GeoMultiSens project funded 

11# by the German Federal Ministry of Education and Research 

12# (project grant code: 01 IS 14 010 A-C). 

13# 

14# Licensed under the Apache License, Version 2.0 (the "License"); 

15# you may not use this file except in compliance with the License. 

16# You may obtain a copy of the License at 

17# 

18# http://www.apache.org/licenses/LICENSE-2.0 

19# 

20# Please note the following exception: `spechomo` depends on tqdm, which is 

21# distributed under the Mozilla Public Licence (MPL) v2.0 except for the files 

22# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore". 

23# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE. 

24# 

25# Unless required by applicable law or agreed to in writing, software 

26# distributed under the License is distributed on an "AS IS" BASIS, 

27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

28# See the License for the specific language governing permissions and 

29# limitations under the License. 

30 

31"""Main module.""" 

32 

33import os 

34import logging # noqa F401 # flake8 issue 

35from typing import Union, Tuple # noqa F401 # flake8 issue 

36from multiprocessing import cpu_count 

37import traceback 

38import time 

39from tqdm import tqdm 

40import numpy as np 

41from geoarray import GeoArray # noqa F401 # flake8 issue 

42from specclassify import classify_image 

43# from specclassify import kNN_MinimumDistance_Classifier 

44 

45from .classifier import Cluster_Learner 

46from .exceptions import ClassifierNotAvailableError 

47from .logging import SpecHomo_Logger 

48from .options import options 

49from .utils import spectra2im, im2spectra 

50 

51 

__author__ = 'Daniel Scheffler'

# default root directory of the classifiers, read from the package options
# (used by the classes below whenever no 'classifier_rootDir' is passed)
_classifier_rootdir = options['classifiers']['rootdir']

55 

56 

class SpectralHomogenizer(object):
    """Class for applying spectral homogenization by applying an interpolation or machine learning approach."""

    def __init__(self, classifier_rootDir='', logger=None, CPUs=None, progress=True):
        """Get instance of SpectralHomogenizer.

        :param classifier_rootDir:  root directory where machine learning classifiers are stored
                                    (the root directory from the package options is used if empty)
        :param logger:              instance of logging.Logger (a SpecHomo_Logger is created if not given)
        :param CPUs:                number of CPUs to use (the result of cpu_count() is used if not given)
        :param progress:            whether to show progress bars
        """
        self.classifier_rootDir = classifier_rootDir if classifier_rootDir else _classifier_rootdir
        self.logger = logger if logger else SpecHomo_Logger(__name__)
        self.CPUs = CPUs if CPUs else cpu_count()
        self.progress = progress

    def interpolate_cube(self, arrcube, source_CWLs, target_CWLs, kind='linear'):
        # type: (Union[np.ndarray, GeoArray], list, list, str) -> GeoArray
        """Spectrally interpolate the bands of an image cube to new central wavelength positions.

        The interpolation runs along the third array axis (axis=2) and extrapolates where the target
        wavelengths are not covered by the source wavelengths. The result is cast to int16 if its value
        range allows it, otherwise to int32.

        :param arrcube:      array to be spectrally interpolated
        :param source_CWLs:  list of source central wavelength positions
        :param target_CWLs:  list of target central wavelength positions
        :param kind:         interpolation kind to be passed to scipy.interpolate.interp1d (default: 'linear')
        :return:             the interpolated cube as GeoArray
        :raises TypeError:   if the interpolated values fit neither into int16 nor into int32
        """
        from scipy.interpolate import interp1d  # import here to avoid static TLS ImportError

        supported_kinds = ('linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic')
        assert kind in supported_kinds, \
            "%s is not a supported kind of spectral interpolation." % kind
        assert arrcube is not None, \
            'L2B_obj.interpolate_cube_linear expects a numpy array as input. Got %s.' % type(arrcube)

        wvl_source = np.array(source_CWLs)
        target_CWLs = np.array(target_CWLs)

        # build a human readable listing of the target wavelengths ("a, b and c")
        wvl_listing = ', '.join(np.round(np.array(target_CWLs[:-1]), 1).astype(str))
        wvl_listing = wvl_listing + ' and %s' % np.round(target_CWLs[-1], 1)
        self.logger.info(
            'Performing spectral homogenization (%s interpolation) with target wavelength positions at %s nm.'
            % (kind, wvl_listing))

        interpolator = interp1d(wvl_source, arrcube, axis=2, kind=kind, fill_value='extrapolate')
        outarr = interpolator(target_CWLs)

        # use the smallest signed integer type (16 or 32 bit) that can hold the interpolated value range
        val_min, val_max = np.min(outarr), np.max(outarr)
        for int_type in (np.int16, np.int32):
            limits = np.iinfo(int_type)
            if val_min >= limits.min and val_max <= limits.max:
                outarr = outarr.astype(int_type)
                break
        else:
            raise TypeError('The interpolated data cube cannot be cast into a 16- or 32-bit integer array.')

        expected_shape = tuple(arrcube.shape[:2]) + (len(target_CWLs),)
        assert outarr.shape == expected_shape

        return GeoArray(outarr)

    def predict_by_machine_learner(self, arrcube, method, src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor,
                                   tgt_LBA, n_clusters=50, classif_alg='MinDist', kNN_n_neighbors=10,
                                   global_clf_threshold=options['classifiers']['prediction']['global_clf_threshold'],
                                   src_nodataVal=None, out_nodataVal=None, compute_errors=False, bandwise_errors=True,
                                   fallback_argskwargs=None):
        # type: (Union[np.ndarray, GeoArray], str, str, str, list, str, str, list, int, str, int, Union[str, int, float], int, int, bool, bool, dict) -> tuple # noqa
        """Predict spectral bands of target sensor by applying a machine learning approach.

        If no suitable classifier can be loaded, the method falls back to
        SpectralHomogenizer.interpolate_cube() (provided that 'fallback_argskwargs' is given), otherwise the
        exception raised while loading the classifier is re-raised.

        NOTE: You may use the function spechomo.utils.list_available_transformations() to get a list of available
              transformations. You may also copy the input parameters for this method from the output there.

        :param arrcube:                 input image array for target sensor spectral band prediction
                                        (rows x cols x bands)
        :param method:                  machine learning approach to be used for spectral bands prediction
                                        'LR': Linear Regression
                                        'RR': Ridge Regression
                                        'QR': Quadratic Regression
                                        'RFR': Random Forest Regression (50 trees; does not allow spectral
                                               sub-clustering)
        :param src_satellite:           source satellite, e.g., 'Landsat-8'
        :param src_sensor:              source sensor, e.g., 'OLI_TIRS'
        :param src_LBA:                 source LayerBandsAssignment  # TODO document this
        :param tgt_satellite:           target satellite, e.g., 'Landsat-8'
        :param tgt_sensor:              target sensor, e.g., 'OLI_TIRS'
        :param tgt_LBA:                 target LayerBandsAssignment  # TODO document this
        :param n_clusters:              Number of spectral clusters to be used during LR/ RR/ QR homogenization.
                                        E.g., 50 means that the image to be converted to the spectral target sensor
                                        is clustered into 50 spectral clusters and one separate machine learner per
                                        cluster is applied to the input data to predict the homogenized image.
                                        If set to 1, the source image is not clustered and only one machine learning
                                        classifier is used for prediction.
        :param classif_alg:             Multispectral classification algorithm to be used to determine the spectral
                                        cluster each pixel belongs to.
                                        'MinDist': Minimum Distance (Nearest Centroid)
                                        'kNN': k-nearest-neighbour
                                        'kNN_MinDist': k-nearest-neighbour Minimum Distance (Nearest Centroid)
                                        'SAM': spectral angle mapping
                                        'kNN_SAM': k-nearest-neighbour spectral angle mapping
                                        'SID': spectral information divergence
                                        'FEDSA': fused euclidian distance / spectral angle
                                        'kNN_FEDSA': k-nearest-neighbour fused euclidian distance / spectral angle
        :param kNN_n_neighbors:         The number of neighbors to be considered. Only passed on if 'classif_alg'
                                        starts with 'kNN', otherwise this parameter is ignored.
        :param global_clf_threshold:    If given, all pixels where the computed similarity metric (set by
                                        'classif_alg') exceeds the given threshold are predicted using the global
                                        classifier (based on a single transformation per band).
                                        - may be given as float, integer or string to label a certain distance
                                          percentile
                                        - if given as string, it must match the format, e.g., '10%' for labelling the
                                          worst 10 % of the distances as unclassified
        :param src_nodataVal:           no data value of source image (arrcube)
                                        - if no nodata value is set, it is tried to be auto-computed from arrcube
        :param out_nodataVal:           no data value of predicted image
        :param compute_errors:          whether to compute pixel- / bandwise model errors for estimated pixel values
                                        (default: false)
        :param bandwise_errors:         whether to compute error information for each band separately (True - default)
                                        or to average errors over bands using median (False)
                                        (in case of fallback, the errors are arrays of zeros either way)
        :param fallback_argskwargs:     keyword arguments to be passed to the fallback algorithm
                                        SpectralHomogenizer.interpolate_cube() in case harmonization fails
                                        (must contain the key 'target_CWLs')
        :return:                        predicted array (rows x columns x bands) and the errors (None if
                                        'compute_errors' is False)
        :rtype:                         Tuple[np.ndarray, Union[np.ndarray, None]]
        """
        # TODO: add LBA validation to .predict()
        predictor_kwargs = {
            'method': method,
            'classifier_rootDir': self.classifier_rootDir,
            'n_clusters': n_clusters,
            'classif_alg': classif_alg,
            'CPUs': self.CPUs,
            'progress': self.progress,
            'logger': self.logger,
        }
        if classif_alg.startswith('kNN'):
            predictor_kwargs['n_neighbors'] = kNN_n_neighbors

        RSI_CP = RSImage_ClusterPredictor(**predictor_kwargs)

        ######################
        # get the classifier #
        ######################

        cls = None
        exc = Exception()  # replaced by the caught exception if the classifier cannot be loaded
        try:
            cls = RSI_CP.get_classifier(src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA)

        except FileNotFoundError as err:
            exc = err
            self.logger.warning('No machine learning classifier available that fulfills the specifications of the '
                                'spectral reference sensor. Falling back to linear interpolation for performing '
                                'spectral homogenization.')

        except ClassifierNotAvailableError as err:
            exc = err
            self.logger.error('\nAn error occurred during spectral homogenization using the %s classifier. '
                              'Falling back to linear interpolation. Error message was: ' % method)
            self.logger.error(traceback.format_exc())

        ##################
        # run prediction #
        ##################

        # neither a classifier nor a fallback is available -> nothing can be done
        if not cls and not fallback_argskwargs:
            raise exc

        errors = None

        if cls:
            self.logger.info('Performing spectral homogenization using %s. Target is %s %s %s.'
                             % (method, tgt_satellite, tgt_sensor, tgt_LBA))

            if src_nodataVal is not None:
                cmap_nodataVal = src_nodataVal
            else:
                cmap_nodataVal = -9999

            im_homo = RSI_CP.predict(arrcube,
                                     classifier=cls,
                                     in_nodataVal=src_nodataVal,
                                     cmap_nodataVal=cmap_nodataVal,
                                     out_nodataVal=out_nodataVal,
                                     global_clf_threshold=global_clf_threshold)  # type: GeoArray

            if compute_errors:
                errors = RSI_CP.compute_prediction_errors(im_homo, cls,
                                                          nodataVal=src_nodataVal,
                                                          cmap_nodataVal=cmap_nodataVal)

                if not bandwise_errors:
                    # collapse the band dimension using the median
                    errors = np.median(errors, axis=2).astype(errors.dtype)

        else:
            # fallback: use linear interpolation and set errors to an array of zeros
            im_homo = self.interpolate_cube(**fallback_argskwargs)  # type: GeoArray

            if compute_errors:
                self.logger.warning("Spectral homogenization algorithm had to be performed by linear interpolation "
                                    "(fallback). Unable to compute any accuracy information from that.")
                errors = \
                    np.zeros_like(im_homo, dtype=np.int16) if bandwise_errors else \
                    np.zeros(im_homo.shape[:2], dtype=np.int16)

        # add metadata
        if cls:
            tgt_wavelengths = cls.tgt_wavelengths
        else:
            tgt_wavelengths = fallback_argskwargs['target_CWLs']

        im_homo.metadata.band_meta['wavelength'] = tgt_wavelengths
        im_homo.classif_map = RSI_CP.classif_map
        im_homo.distance_metrics = RSI_CP.distance_metrics

        # handle negative values in the predicted image => set these pixels to nodata
        # im_homo = set_negVals_to_nodata(im_homo, out_nodataVal)

        return im_homo, errors

263 

264 

class RSImage_ClusterPredictor(object):
    """Predictor class applying the predict() function of a machine learning classifier described by the given args."""

    def __init__(self, method='LR', n_clusters=50, classif_alg='MinDist', classifier_rootDir='',
                 CPUs=1, logger=None, progress=True, **kw_clf_init):
        # type: (str, int, str, str, Union[None, int], logging.Logger, bool, dict) -> None
        """Get an instance of RSImage_ClusterPredictor.

        :param method:              machine learning approach to be used for spectral bands prediction
                                    'LR': Linear Regression
                                    'RR': Ridge Regression
                                    'QR': Quadratic Regression
                                    'RFR': Random Forest Regression (50 trees; does not allow spectral sub-clustering)
        :param n_clusters:          Number of spectral clusters to be used during LR/ RR/ QR homogenization.
                                    E.g., 50 means that the image to be converted to the spectral target sensor
                                    is clustered into 50 spectral clusters and one separate machine learner per
                                    cluster is applied to the input data to predict the homogenized image. If
                                    'n_clusters' is set to 1, the source image is not clustered and
                                    only one machine learning classifier is used for prediction.
                                    NOTE: forced to 1 if 'method' is 'RFR'.
        :param classif_alg:         algorithm to be used for image classification
                                    (to define which cluster each pixel belongs to)
                                    'MinDist': Minimum Distance (Nearest Centroid)
                                    'kNN': k-nearest-neighbour
                                    'kNN_MinDist': k-nearest-neighbour Minimum Distance (Nearest Centroid)
                                    'SAM': spectral angle mapping
                                    'kNN_SAM': k-nearest-neighbour spectral angle mapping
                                    'SID': spectral information divergence
                                    'FEDSA': fused euclidian distance / spectral angle
                                    'kNN_FEDSA': k-nearest-neighbour fused euclidian distance / spectral angle
        :param classifier_rootDir:  root directory where machine learning classifiers are stored
                                    (converted to an absolute path; the root directory from the package options is
                                    used if empty)
        :param CPUs:                number of CPUs to use (default: 1; None or 0 -> result of cpu_count())
        :param logger:              instance of logging.Logger()
        :param progress:            whether to show progress bars
        :param kw_clf_init:         keyword arguments to be passed to the image classification,
                                    e.g., 'n_neighbors' sets the number of neighbours to be considered in kNN
                                    classification algorithms (set by 'classif_alg'); it is limited to the number
                                    of clusters here
        """
        self.method = method
        self.n_clusters = n_clusters
        self.classifier_rootDir = os.path.abspath(classifier_rootDir) if classifier_rootDir else _classifier_rootdir
        self.classif_map = None
        self.classif_map_fractions = None
        self.distance_metrics = None
        self.CPUs = CPUs or cpu_count()
        self.classif_alg = classif_alg
        self.logger = logger or SpecHomo_Logger(__name__)  # must be pickable
        self.progress = progress
        self.kw_clf_init = kw_clf_init

        # validate
        if method == 'RFR' and n_clusters > 1:
            self.logger.warning("The spectral homogenization method 'Random Forest Regression' does not allow spectral "
                                "sub-clustering. Setting 'n_clusters' to 1.")
            self.n_clusters = 1

        # there cannot be more neighbours than clusters
        if self.classif_alg.startswith('kNN') and \
           'n_neighbors' in kw_clf_init and \
           self.n_clusters < kw_clf_init['n_neighbors']:
            self.kw_clf_init['n_neighbors'] = self.n_clusters

    def get_classifier(self, src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA):
        # type: (str, str, list, str, str, list) -> Cluster_Learner
        """Select the correct machine learning classifier out of previously saved classifier collections.

        Describe the classifier specifications with the given arguments. If the classifier is not found in the
        default root directory and no '<method>_classifiers.zip' exists there, the pre-trained classifiers are
        downloaded first and loading is tried again.

        :param src_satellite:   source satellite, e.g., 'Landsat-8'
        :param src_sensor:      source sensor, e.g., 'OLI_TIRS'
        :param src_LBA:         source LayerBandsAssignment
        :param tgt_satellite:   target satellite, e.g., 'Landsat-8'
        :param tgt_sensor:      target sensor, e.g., 'OLI_TIRS'
        :param tgt_LBA:         target LayerBandsAssignment
        :return:                classifier instance loaded from disk
        :raises FileNotFoundError:  if no suitable classifier is found (re-raised directly in case of a user
                                    provided root directory)
        """
        args_fd = (self.classifier_rootDir, self.method, self.n_clusters,
                   src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA)

        try:
            CL = Cluster_Learner.from_disk(*args_fd)

        except FileNotFoundError:
            if self.classifier_rootDir == _classifier_rootdir:
                # the default root directory is used

                if not os.path.exists(os.path.join(_classifier_rootdir, '%s_classifiers.zip' % self.method)):
                    # download the classifiers
                    self.logger.info('The pre-trained classifiers have not been downloaded yet. Downloading...')

                    from .utils import download_pretrained_classifiers
                    download_pretrained_classifiers(method=self.method,
                                                    tgt_dir=self.classifier_rootDir)

                else:
                    self.logger.error('%s classifiers found at %s. However, they do not contain a suitable classifier '
                                      'for the current predition. If desired, delete the existing classifiers and try '
                                      'again. Pre-trained classifiers are then automatically downloaded.'
                                      % (self.method, self.classifier_rootDir))

                # try again
                CL = Cluster_Learner.from_disk(*args_fd)

            else:
                # classifier not found in the user provided root directory
                raise

        return CL

    @staticmethod
    def _mask_saturated_pixels(im_tile_pred, out_dtype, out_nodataVal):
        # type: (np.ndarray, np.dtype, Union[int, float, None]) -> int
        """Set pixels exceeding the value range of an integer output data type to no-data (in-place).

        :param im_tile_pred:    3D array of predicted values (modified in-place)
        :param out_dtype:       data type of the output image; nothing is done unless it is an integer type
        :param out_nodataVal:   value written to ALL bands of each pixel that is saturated in ANY band;
                                if None, the tile is left untouched because there is no value to write
        :return:                number of pixels that were set to no-data
        """
        # np.issubdtype is needed here: a dtype instance is never an instance of np.integer
        if out_nodataVal is None or not np.issubdtype(out_dtype, np.integer):
            return 0

        limits = np.iinfo(out_dtype)

        # cheap range check first to avoid computing the mask for tiles without saturation
        if np.min(im_tile_pred) >= limits.min and np.max(im_tile_pred) <= limits.max:
            return 0

        # the parentheses are mandatory since '|' binds stronger than the comparison operators
        mask_saturated = np.any((im_tile_pred > limits.max) |
                                (im_tile_pred < limits.min),
                                axis=2)
        im_tile_pred[mask_saturated] = out_nodataVal

        return int(np.sum(mask_saturated))

    def predict(self, image, classifier, in_nodataVal=None, out_nodataVal=None, cmap_nodataVal=-9999,
                global_clf_threshold=None, unclassified_pixVal=-1):
        # type: (Union[np.ndarray, GeoArray], Cluster_Learner, float, float, float, Union[str, int, float], int) -> GeoArray  # noqa
        """Apply the prediction function of the given specifier to the given remote sensing image.

        NOTE: The classification map is only computed if self.classif_map is None. The resulting classification
              map, distance metrics and classification map fractions are stored in self.classif_map,
              self.distance_metrics and self.classif_map_fractions.

        :param image:                   3D array representing the input image
        :param classifier:              the classifier instance
        :param in_nodataVal:            no data value of the input image
                                        (auto-computed if not given or contained in image GeoArray)
        :param out_nodataVal:           no data value written into the predicted image
                                        (copied from the input image if not given)
        :param cmap_nodataVal:          no data value for the classification map
                                        in case more than one sub-classes are used for prediction (default: -9999)
        :param global_clf_threshold:    If given, all pixels where the computed similarity metric (set by
                                        'classif_alg') exceeds the given threshold are predicted using the global
                                        classifier (based on a single transformation per band).
                                        - not usable for 'kNN'
                                        - may be given as float, integer or string to label a certain distance
                                          percentile
                                        - if given as string, it must match the format, e.g., '10%' for labelling the
                                          worst 10 % of the distances as unclassified
        :param unclassified_pixVal:     pixel value to be used in the classification map for unclassified pixels
                                        (default: -1)
        :return:                        3D array representing the predicted spectral image cube
        """
        image = image if isinstance(image, GeoArray) else GeoArray(image, nodata=in_nodataVal)

        # ensure image.nodata is present (important for classify_image() -> overwrites cmap at nodata positions)
        image.nodata = in_nodataVal if in_nodataVal is not None else image.nodata  # might be auto-computed here
        in_nodataVal = image.nodata

        ##########################
        # get classification map #
        ##########################

        # assign each input pixel to a cluster (compute classification with cluster centers as endmembers)
        if self.classif_map is None:
            if self.n_clusters > 1:
                self.logger.info('Assigning material-specific regressors to each image pixel.')

                t0 = time.time()
                kw_clf = dict(classif_alg=self.classif_alg,
                              in_nodataVal=image.nodata,
                              cmap_nodataVal=cmap_nodataVal,  # written into classif_map at nodata
                              CPUs=self.CPUs,
                              return_distance=True,
                              **self.kw_clf_init)

                if self.classif_alg in ['MinDist', 'kNN_MinDist', 'SAM', 'kNN_SAM', 'SID', 'FEDSA', 'kNN_FEDSA']:
                    kw_clf.update(dict(unclassified_threshold=global_clf_threshold,
                                       unclassified_pixVal=unclassified_pixVal))

                if self.classif_alg == 'RF':
                    train_spectra = np.vstack([classifier.MLdict[clust].cluster_sample_spectra
                                               for clust in range(classifier.n_clusters)])
                    # NOTE(review): assumes 100 sample spectra per cluster - confirm against Cluster_Learner
                    train_labels = list(np.hstack([[i] * 100
                                                   for i in range(classifier.n_clusters)]))
                else:
                    train_spectra = classifier.cluster_centers
                    train_labels = classifier.cluster_pixVals

                # run classification
                #   - uses 3 neighbors by default in case of kNN classifiers
                self.classif_map, self.distance_metrics = classify_image(image, train_spectra, train_labels, **kw_clf)

                self.logger.info('Total classification time: %s'
                                 % time.strftime("%H:%M:%S", time.gmtime(time.time() - t0)))

            else:
                # no clustering -> every pixel is assigned to the first (only) cluster label
                self.classif_map = GeoArray(np.full((image.rows,
                                                     image.cols),
                                                    classifier.cluster_pixVals[0],
                                                    np.int16),
                                            nodata=cmap_nodataVal)

                # overwrite all pixels where the input image contains nodata in ANY band
                # (would lead to faulty predictions due to multivariate prediction algorithms)
                if in_nodataVal is not None and cmap_nodataVal is not None:
                    self.classif_map[np.any(image[:] == image.nodata, axis=2)] = cmap_nodataVal

                self.distance_metrics = np.zeros_like(self.classif_map,
                                                      np.float32)

        ##############################
        # compute prediction weights #
        ##############################

        # compute the weights (only needed in case of multiple kNN classifiers)
        if classifier.n_clusters > 1 and \
           self.classif_map.ndim > 2:

            self.logger.info('Computing prediction weights per pixel for each regressor.')

            if self.classif_alg == 'kNN_SAM':
                # scale SAM values between 0 and 15 degrees spectral angle
                dist_min, dist_max = 0, 15
            else:
                if in_nodataVal is not None:
                    # exclude distances where cmap contains nodata (-9999) or unclassified (-1) values
                    # NOTE(review): '> 0' also excludes pixels whose nearest cluster has the label 0
                    #               - confirm whether '>= 0' is intended
                    dists4stats = self.distance_metrics[self.classif_map[:, :, 0] > 0]
                else:
                    dists4stats = self.distance_metrics

                dist_min, dist_max = np.min(dists4stats), np.percentile(dists4stats, 90)

            # normalize the distances and convert them to weights (small distance -> large weight)
            dist_norm = (self.distance_metrics - dist_min) / \
                        (dist_max - dist_min)
            weights = 1 - dist_norm
            weights[weights < 0] = 1e-10  # set negative weights to 0 but avoid ZeroDivisionError

        else:
            weights = None

        ####################
        # apply prediction #
        ####################

        self.logger.info(f'Starting prediction with {self.method} regressor, {self.n_clusters} clusters, '
                         f'{self.classif_alg}.')

        # adjust classifier for multiprocessing
        if self.CPUs is None or self.CPUs > 1:
            # FIXME does not work -> parallelize with https://github.com/ajtulloch/sklearn-compiledtrees?
            classifier.n_jobs = cpu_count() if self.CPUs is None else self.CPUs

        # get an empty GeoArray for the prediction result
        t0 = time.time()
        out_nodataVal = out_nodataVal if out_nodataVal is not None else image.nodata
        image_predicted = GeoArray(np.empty((image.rows,
                                             image.cols,
                                             classifier.tgt_n_bands),
                                            dtype=image.dtype),
                                   geotransform=image.gt,
                                   projection=image.prj,
                                   nodata=out_nodataVal,
                                   bandnames=['B%s' % i
                                              if len(i) == 2
                                              else 'B0%s' % i
                                              for i in classifier.tgt_LBA])

        # set image_predicted to nodata at nodata positions of the input image
        if out_nodataVal is not None:
            image_predicted[~image.mask_nodata[:]] = out_nodataVal

        # NOTE:
        # - prediction now only runs on the remaining pixels (that contain data)
        # - computation is running in chunks of 50,000 spectra to save memory
        #   (classifier.predict returns float32) and speed up processing
        # ----------------------------------------------------------------------

        # get all spectra at pixels that really contain data and
        # reshape them to represent a single image column
        # (classifier.predict expects a 3D image-like input array)
        spectra_at_datapos = image[image.mask_nodata[:]]
        n_spectra = spectra_at_datapos.shape[0]
        spectra_as_im = GeoArray(spectra2im(spectra_at_datapos, n_spectra, 1))

        # get the corresponding weights and classification maps (also as one image column)
        if weights is not None:
            # in case of kNN classifiers
            weights_datapos = spectra2im(weights[image.mask_nodata[:]], n_spectra, 1)
            classif_map_datapos = spectra2im(self.classif_map[image.mask_nodata[:]], n_spectra, 1)
        else:
            # in case no kNN classifier was used and we don't have to respect any weights
            weights_datapos = None
            classif_map_datapos = self.classif_map[image.mask_nodata[:]].reshape(-1, 1)

        # spectra_predicted will be filled while looping over chunks
        spectra_predicted = np.empty((n_spectra, image_predicted.bands), image_predicted.dtype)
        n_saturated_px = 0

        for ((rS, rE), (cS, cE)), im_tile in tqdm(spectra_as_im.tiles(tilesize=(50000, 1)),
                                                  desc='Predicting in chunks',
                                                  disable=not self.progress):

            classif_map_tile = classif_map_datapos[rS: rE + 1, cS: cE + 1]  # integer array

            # predict!
            if self.classif_map.ndim == 2:
                im_tile_pred = \
                    classifier.predict(im_tile,
                                       classif_map_tile,
                                       nodataVal=out_nodataVal,
                                       cmap_nodataVal=cmap_nodataVal,
                                       cmap_unclassifiedVal=unclassified_pixVal)

            else:
                weights_tile = weights_datapos[rS: rE + 1, cS: cE + 1]  # float array

                im_tile_pred = \
                    classifier.predict_weighted_averages(im_tile,
                                                         classif_map_tile,
                                                         weights_tile,
                                                         nodataVal=out_nodataVal,
                                                         cmap_nodataVal=cmap_nodataVal,
                                                         cmap_unclassifiedVal=unclassified_pixVal)

            # set saturated pixels (exceeding the output data range with respect to the data type) to no-data
            # NOTE: this is computed on the chunks to save memory
            n_saturated_px += self._mask_saturated_pixels(im_tile_pred, image_predicted.dtype, out_nodataVal)

            spectra_predicted[rS:rE + 1, :] = im2spectra(im_tile_pred)  # [n_spectra x n_tgt_bands]

        # fill in the predicted spectra
        image_predicted[image.mask_nodata[:]] = spectra_predicted

        if n_saturated_px:
            # the whole percentage expression must be parenthesized, otherwise '%' formats the string first
            n_pixels = image_predicted.shape[0] * image_predicted.shape[1]
            self.logger.warning("%.2f %% of the predicted pixels are saturated and set to no-data."
                                % (n_saturated_px / n_pixels * 100))

        self.logger.info('Total prediction time: %s' % time.strftime("%H:%M:%S", time.gmtime(time.time() - t0)))

        ###############################
        # complete prediction results #
        ###############################

        # re-apply nodata values to predicted result
        if image.nodata is not None:
            mask_nodata = image.calc_mask_nodata(overwrite=True, flag='any')
            image_predicted[~mask_nodata] = out_nodataVal

        # copy mask_nodata
        image_predicted.mask_nodata = image.mask_nodata

        # append weights to predicted image
        image_predicted.weights = weights

        # append some statistics regarding the homogenization
        cmap_vals, cmap_valcounts = np.unique(self.classif_map, return_counts=True)
        cmap_valfractions = cmap_valcounts / self.classif_map.size
        self.classif_map_fractions = dict(zip(list(cmap_vals), list(cmap_valfractions)))

        # log the pixel fraction where material-specific regressors were applied
        # (only if there are unclassified pixels, i.e., pixels predicted by the global regressor)
        frac = self.classif_map_fractions
        if unclassified_pixVal in frac:
            glob_regr_perc = frac[unclassified_pixVal] * 100
            nodata_perc = frac[cmap_nodataVal] * 100 if cmap_nodataVal in frac else 0
            data_perc = 100 - nodata_perc
            opt_regr_perc = 100 - glob_regr_perc - nodata_perc
            self.logger.info(f"No-data fraction:\t{nodata_perc:.1f}%")
            self.logger.info(f"Regressor fractions:\t"
                             f"{opt_regr_perc / data_perc * 100:.1f}% material optimized; "
                             f"{glob_regr_perc / data_perc * 100:.1f}% global regressor")

        return image_predicted

    def compute_prediction_errors(self, im_predicted, cluster_classifier, nodataVal=None, cmap_nodataVal=None):
        # type: (Union[np.ndarray, GeoArray], Cluster_Learner, float, float) -> np.ndarray
        """Compute errors that quantify prediction inaccurracy per band and per pixel.

        The per-band RMSE values of the regressor assigned to each pixel (according to self.classif_map) are
        rounded to integers and written to the respective pixel positions.

        :param im_predicted:        3D array representing the predicted image
        :param cluster_classifier:  instance of Cluster_Learner
        :param nodataVal:           no data value of the input image
                                    (auto-computed if not given or contained in im_predicted GeoArray)
                                    NOTE: The value is also used as output nodata value for the errors array.
        :param cmap_nodataVal:      no data value for the classification map
                                    in case more than one sub-classes are used for prediction
        :return:                    3D array representing prediction errors per band and pixel
        :raises ValueError:             if the error statistics of a classifier do not match the number of bands
        :raises RuntimeError:           if self.predict() has not been run before
        :raises NotImplementedError:    if the classification map is 3-dimensional
        """
        im_predicted = im_predicted if isinstance(im_predicted, GeoArray) else GeoArray(im_predicted, nodata=nodataVal)
        im_predicted.nodata = nodataVal if nodataVal is not None else im_predicted.nodata  # might be auto-computed here

        for clf in cluster_classifier:
            if not len(clf.rmse_per_band) == GeoArray(im_predicted).bands:
                raise ValueError('The given classifier contains error statistics incompatible to the shape of the '
                                 'image.')
        if self.classif_map is None:
            raise RuntimeError('self.classif_map must be generated by running self.predict() beforehand.')

        if self.classif_map.ndim == 3:
            # FIXME: error computation does not work for kNN algorithms so far (self.classif_map is 3D instead of 2D)
            raise NotImplementedError('Error computation for 3-dimensional classification maps (e.g., due to kNN '
                                      'classification algorithms) is not yet implemented.')

        errors = np.empty_like(im_predicted)

        # iterate over all cluster labels (np.unique returns them sorted) and copy rmse values
        for pixVal in np.unique(self.classif_map):
            if pixVal == cmap_nodataVal:
                continue

            self.logger.info('Inpainting error values for cluster #%s...' % pixVal)

            # the label -1 marks unclassified pixels which are handled by the global classifier
            clf2use = cluster_classifier.MLdict[pixVal] if pixVal != -1 else cluster_classifier.global_clf
            rmse_per_band_int = np.round(clf2use.rmse_per_band, 0).astype(np.int16)
            errors[self.classif_map[:] == pixVal] = rmse_per_band_int

        # TODO validate this equation
        # errors = (errors * im_predicted[:] / 10000).astype(errors.dtype)

        # re-apply nodata values to predicted result
        if im_predicted.nodata is not None:
            errors[im_predicted.mask_nodata.astype(np.int8) == 0] = im_predicted.nodata

        return errors