Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2 

3# spechomo, Spectral homogenization of multispectral satellite data 

4# 

5# Copyright (C) 2019-2021 

6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de) 

7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam, 

8# Germany (https://www.gfz-potsdam.de/) 

9# 

10# This software was developed within the context of the GeoMultiSens project funded 

11# by the German Federal Ministry of Education and Research 

12# (project grant code: 01 IS 14 010 A-C). 

13# 

14# Licensed under the Apache License, Version 2.0 (the "License"); 

15# you may not use this file except in compliance with the License. 

16# You may obtain a copy of the License at 

17# 

18# http://www.apache.org/licenses/LICENSE-2.0 

19# 

20# Please note the following exception: `spechomo` depends on tqdm, which is 

21# distributed under the Mozilla Public Licence (MPL) v2.0 except for the files 

22# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore". 

23# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE. 

24# 

25# Unless required by applicable law or agreed to in writing, software 

26# distributed under the License is distributed on an "AS IS" BASIS, 

27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

28# See the License for the specific language governing permissions and 

29# limitations under the License. 

30 

31import os 

32import tempfile 

33import zipfile 

34from collections import OrderedDict 

35from pprint import pformat 

36from typing import Union, List, TYPE_CHECKING # noqa F401 # flake8 issue 

37import json 

38import builtins 

39 

40if TYPE_CHECKING: 

41 from matplotlib import pyplot as plt # noqa F401 # flake8 issue 

42 

43from tqdm import tqdm 

44import dill 

45import numpy as np 

46from pandas import DataFrame 

47from geoarray import GeoArray # noqa F401 # flake8 issue 

48 

49from .classifier_creation import get_filename_classifier_collection, get_machine_learner 

50from .exceptions import ClassifierNotAvailableError 

51from .utils import im2spectra, spectra2im 

52 

53 

54class Cluster_Learner(object): 

55 """ 

56 A class that holds the machine learning classifiers for multiple spectral clusters as well as a global classifier. 

57 

58 These classifiers can be applied to an input sensor image by using the predict method. 

59 """ 

60 

61 def __init__(self, dict_clust_MLinstances, global_classifier): 

62 # type: (Union[dict, ClassifierCollection], any) -> None 

63 """Get an instance of Cluster_Learner. 

64 

65 :param dict_clust_MLinstances: a dictionary of cluster specific machine learning classifiers 

66 :param global_classifier: the global machine learning classifier to be applied at image positions with 

67 high spectral dissimilarity to the available cluster centers 

68 """ 

69 self.cluster_pixVals = list(sorted(dict_clust_MLinstances.keys())) # type: List[int] 

70 self.MLdict = OrderedDict((clust, dict_clust_MLinstances[clust]) for clust in self.cluster_pixVals) 

71 self.global_clf = global_classifier 

72 sample_MLinst = list(self.MLdict.values())[0] 

73 self.src_satellite = sample_MLinst.src_satellite 

74 self.src_sensor = sample_MLinst.src_sensor 

75 self.tgt_satellite = sample_MLinst.tgt_satellite 

76 self.tgt_sensor = sample_MLinst.tgt_sensor 

77 self.src_LBA = sample_MLinst.src_LBA 

78 self.tgt_LBA = sample_MLinst.tgt_LBA 

79 self.src_n_bands = sample_MLinst.src_n_bands 

80 self.tgt_n_bands = sample_MLinst.tgt_n_bands 

81 self.src_wavelengths = sample_MLinst.src_wavelengths 

82 self.tgt_wavelengths = sample_MLinst.tgt_wavelengths 

83 self.n_clusters = sample_MLinst.n_clusters 

84 self.cluster_centers = np.array([cc.cluster_center for cc in self.MLdict.values()]) 

85 self.spechomo_version = \ 

86 sample_MLinst.spechomo_version if hasattr(sample_MLinst, 'spechomo_version') else 'NA' 

87 self.spechomo_versionalias = \ 

88 sample_MLinst.spechomo_versionalias if hasattr(sample_MLinst, 'spechomo_versionalias') else 'NA' 

89 

90 @classmethod 

91 def from_disk(cls, classifier_rootDir, method, n_clusters, 

92 src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA, n_estimators=50): 

93 # type: (str, str, int, str, str, list, str, str, list, int) -> Cluster_Learner 

94 """Read a previously saved ClusterLearner from disk and return a ClusterLearner instance. 

95 

96 Describe the classifier specifications with the given arguments. 

97 

98 :param classifier_rootDir: root directory of the classifiers 

99 :param method: harmonization method 

100 'LR': Linear Regression 

101 'RR': Ridge Regression 

102 'QR': Quadratic Regression 

103 'RFR': Random Forest Regression (50 trees; does not allow spectral sub-clustering) 

104 :param n_clusters: number of clusters 

105 :param src_satellite: source satellite, e.g., 'Landsat-8' 

106 :param src_sensor: source sensor, e.g., 'OLI_TIRS' 

107 :param src_LBA: source LayerBandsAssignment 

108 :param tgt_satellite: target satellite, e.g., 'Landsat-8' 

109 :param tgt_sensor: target sensor, e.g., 'OLI_TIRS' 

110 :param tgt_LBA: target LayerBandsAssignment 

111 :param n_estimators: number of estimators (only used in case of method=='RFR' 

112 :return: classifier instance loaded from disk 

113 """ 

114 # get path of classifier zip archive 

115 args = (method, src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA) 

116 kw_clfinit = dict(n_estimators=n_estimators) 

117 

118 if os.path.isfile(os.path.join(classifier_rootDir, '%s_classifiers.zip' % method)): 

119 # get cluster specific classifiers and store them in a ClassifierCollection dictionary 

120 dict_clust_MLinstances = cls._read_ClassifierCollection_from_zipFile( 

121 classifier_rootDir, *args, n_clusters=n_clusters, **kw_clfinit) 

122 

123 # get the global classifier and add it as cluster label '-1' 

124 global_clf = cls._read_ClassifierCollection_from_zipFile( 

125 classifier_rootDir, *args, n_clusters=1, **kw_clfinit)[0] 

126 

127 elif os.path.isdir(classifier_rootDir): 

128 # get cluster specific classifiers and store them in a ClassifierCollection dictionary 

129 dict_clust_MLinstances = cls._read_ClassifierCollection_from_unzippedFile( 

130 classifier_rootDir, *args, n_clusters=n_clusters, **kw_clfinit) 

131 

132 # get the global classifier and add it as cluster label '-1' 

133 global_clf = cls._read_ClassifierCollection_from_unzippedFile( 

134 classifier_rootDir, *args, n_clusters=1, **kw_clfinit)[0] 

135 

136 else: 

137 raise FileNotFoundError("No '%s' classifiers available at %s." % (method, classifier_rootDir)) 

138 

139 # create an instance of ClusterLearner based on the ClassifierCollection dictionary 

140 return Cluster_Learner(dict_clust_MLinstances, global_clf) 

141 

142 @staticmethod 

143 def _read_ClassifierCollection_from_zipFile(classifier_rootDir, method, src_satellite, src_sensor, 

144 src_LBA, tgt_satellite, tgt_sensor, tgt_LBA, n_clusters, 

145 **kw_clfinit): 

146 # type: (str, str, str, str, list, str, str, list, int, dict) -> ClassifierCollection 

147 

148 path_classifier_zip = os.path.join(classifier_rootDir, '%s_classifiers.zip' % method) 

149 

150 # read requested classifier from zip archive and create a ClassifierCollection 

151 with zipfile.ZipFile(path_classifier_zip, "r") as zf, tempfile.TemporaryDirectory() as td: 

152 fName_clf = get_filename_classifier_collection(method, src_satellite, src_sensor, n_clusters=n_clusters, 

153 **kw_clfinit) 

154 try: 

155 zf.extract(fName_clf, td) 

156 except KeyError: 

157 raise FileNotFoundError("No classifiers for %s %s with %d clusters contained in %s." 

158 % (src_satellite, src_sensor, n_clusters, path_classifier_zip)) 

159 

160 return Cluster_Learner._read_ClassifierCollection_from_unzippedFile( 

161 td, method, src_satellite, src_sensor, src_LBA, 

162 tgt_satellite, tgt_sensor, tgt_LBA, n_clusters) 

163 

164 @staticmethod 

165 def _read_ClassifierCollection_from_unzippedFile(classifier_rootDir, method, src_satellite, src_sensor, src_LBA, 

166 tgt_satellite, tgt_sensor, tgt_LBA, n_clusters, **kw_clfinit): 

167 # type: (str, str, str, str, list, str, str, list, int, dict) -> ClassifierCollection 

168 

169 fName_clf_clustN = get_filename_classifier_collection(method, src_satellite, src_sensor, 

170 n_clusters=n_clusters, **kw_clfinit) 

171 path_classifier = os.path.join(classifier_rootDir, fName_clf_clustN) 

172 

173 # read requested classifier from zip archive and create a ClassifierCollection 

174 try: 

175 clf_collection = \ 

176 ClassifierCollection(path_classifier)['__'.join(src_LBA)][tgt_satellite, tgt_sensor]['__'.join(tgt_LBA)] 

177 except KeyError: 

178 raise ClassifierNotAvailableError(method, src_satellite, src_sensor, src_LBA, 

179 tgt_satellite, tgt_sensor, tgt_LBA, n_clusters) 

180 

181 # validation 

182 expected_MLtype = type(get_machine_learner(method)) 

183 if len(clf_collection.keys()) != n_clusters: 

184 raise RuntimeError('Read classifier with %s clusters instead of %s.' 

185 % (len(clf_collection.keys()), n_clusters)) 

186 for label, ml in clf_collection.items(): 

187 if not isinstance(ml, expected_MLtype): 

188 raise ValueError("The given dillFile %s contains a spectral cluster (label '%s') with a %s machine " 

189 "learner instead of the expected %s." 

190 % (os.path.basename(path_classifier), label, type(ml), expected_MLtype.__name__,)) 

191 

192 return clf_collection 

193 

194 def __iter__(self): 

195 for cluster in self.cluster_pixVals: 

196 yield self.MLdict[cluster] 

197 

198 def predict(self, im_src, cmap, nodataVal=None, cmap_nodataVal=None, cmap_unclassifiedVal=-1): 

199 # type: (Union[np.ndarray, GeoArray], np.ndarray, Union[int, float], Union[int, float], Union[int, float]) -> np.ndarray # noqa 

200 """Predict target satellite spectral information using separate prediction coefficients for spectral clusters. 

201 

202 :param im_src: input image to be used for prediction 

203 :param cmap: classification map that assigns each image spectrum to a corresponding cluster 

204 -> must be a 2D np.ndarray with the same X-/Y-dimension like im_src 

205 :param nodataVal: nodata value to be used to fill into the predicted image 

206 :param cmap_nodataVal: nodata class value of the nodata class of the classification map 

207 :param cmap_unclassifiedVal: 'unclassified' class value of the nodata class of the classification map 

208 :return: 

209 """ 

210 cluster_labels = sorted(list(np.unique(cmap))) 

211 

212 im_pred = np.full((im_src.shape[0], im_src.shape[1], self.tgt_n_bands), 

213 fill_value=nodataVal if nodataVal is not None else 0, 

214 dtype=np.float32) 

215 

216 if len(cluster_labels) > 1: 

217 # iterate over all cluster labels and apply corresponding machine learner parameters 

218 # to predict target spectra 

219 for pixVal in cluster_labels: 

220 if pixVal == cmap_nodataVal: 

221 # at nodata positions, the predicted pixel value will also be nodata 

222 # -> don't do anything because im_pred was already initialized as an array full of nodata values 

223 continue 

224 

225 elif pixVal == cmap_unclassifiedVal: 

226 # apply global homogenization coefficients 

227 classifier = self.global_clf 

228 

229 else: 

230 # apply cluster specific homogenization coefficients 

231 classifier = self.MLdict[pixVal] 

232 

233 mask_pixVal = cmap == pixVal 

234 im_pred[mask_pixVal] = classifier.predict(im_src[mask_pixVal]) 

235 

236 else: 

237 # predict target spectra directly (much faster than the above algorithm) 

238 pixVal = cluster_labels[0] 

239 

240 if pixVal == cmap_nodataVal: 

241 # im_src consists only of no data values 

242 pass # im_pred keeps at nodataVal 

243 

244 else: 

245 if pixVal == cmap_unclassifiedVal: 

246 # apply global homogenization coefficients 

247 classifier = self.global_clf 

248 else: 

249 classifier = self.MLdict[pixVal] 

250 assert classifier.clusterlabel == pixVal 

251 

252 spectra = im2spectra(im_src) 

253 spectra_pred = classifier.predict(spectra) 

254 im_pred = spectra2im(spectra_pred, im_src.shape[0], im_src.shape[1]) 

255 

256 return im_pred # float32 array 

257 

258 def predict_weighted_averages(self, im_src, cmap_3D, weights_3D=None, nodataVal=None, 

259 cmap_nodataVal=None, cmap_unclassifiedVal=-1): 

260 # type: (Union[np.ndarray, GeoArray], np.ndarray, np.ndarray, Union[int, float], Union[int, float], Union[int, float]) -> np.ndarray # noqa 

261 """Predict target satellite spectral information using separate prediction coefficients for spectral clusters. 

262 

263 NOTE: This version of the prediction function uses the prediction coefficients of multiple spectral clusters 

264 and computes the result as weighted average of them. Therefore, the classification map must assign 

265 multiple spectral clusters to each input pixel. 

266 

267 # NOTE: At unclassified pixels (cmap_3D[y,x,z>0] == -1) the prediction result using global coefficients 

268 # is ignored in the weighted average. In that case the prediction result is based on the found valid 

269 # spectral clusters and is not affected by the global coefficients (should improve prediction results). 

270 

271 :param im_src: input image to be used for prediction 

272 :param cmap_3D: classification map that assigns each image spectrum to multiple corresponding clusters 

273 -> must be a 3D np.ndarray with the same X-/Y-dimension like im_src 

274 :param weights_3D: 

275 :param nodataVal: nodata value to be used to fill into the predicted image 

276 :param cmap_nodataVal: nodata class value of the nodata class of the classification map 

277 :param cmap_unclassifiedVal: 'unclassified' class value of the nodata class of the classification map 

278 :return: 

279 """ 

280 if not cmap_3D.ndim > 2: 

281 raise ValueError('Input classification map needs at least 2 bands to compute prediction results as' 

282 'weighted averages.') 

283 

284 if cmap_3D.shape != weights_3D.shape: 

285 raise ValueError("The input arrays 'cmap_3D' and 'weights_3D' need to have the same dimensions. " 

286 "Received %s vs. %s." % (cmap_3D.shape, weights_3D.shape)) 

287 

288 # predict for each classification map band 

289 ims_pred_temp = [] 

290 

291 for band in range(cmap_3D.shape[2]): 

292 ims_pred_temp.append( 

293 self.predict(im_src, 

294 cmap_3D[:, :, band], 

295 nodataVal=nodataVal, 

296 cmap_nodataVal=cmap_nodataVal, 

297 cmap_unclassifiedVal=cmap_unclassifiedVal 

298 )) 

299 

300 # merge classification results by weighted averaging 

301 nsamp = np.dot(*weights_3D.shape[:2]) 

302 nbandpred = ims_pred_temp[0].shape[2] 

303 nbandscmap = weights_3D.shape[2] 

304 

305 weights = \ 

306 np.ones((nsamp, nbandpred, nbandscmap)) if weights_3D is None else \ 

307 np.tile(weights_3D.reshape((nsamp, 1, nbandscmap)), 

308 (1, nbandpred, 1)) # nclust x n_tgt_bands x n_cmap_bands 

309 

310 # set weighting of unclassified pixel positions to zero (except from the first cmap band) 

311 # -> see NOTE #2 in the docstring 

312 # mask_unclassif = np.tile(cmap_3D.reshape(nsamp, 1, nbandscmap), (1, nbandpred, 1)) == cmap_unclassifiedVal 

313 # mask_unclassif[:, :, :1] = False # if all other clusters are invalid, at least the first one is used for prediction # noqa 

314 # weights[mask_unclassif] = 0 

315 # FIXME this computes the prediction for all k-neighbors, no matter if the weights are 0 

316 spectra_pred = np.average(np.dstack([im2spectra(im) for im in ims_pred_temp]), 

317 weights=weights, 

318 axis=2) 

319 im_pred = spectra2im(spectra_pred, 

320 tgt_rows=im_src.shape[0], 

321 tgt_cols=im_src.shape[1]) 

322 

323 return im_pred 

324 

325 def plot_sample_spectra(self, cluster_label='all', include_mean_spectrum=True, include_median_spectrum=True, 

326 ncols=5, **kw_fig): 

327 # type: (Union[str, int, List], bool, bool, int, dict) -> plt.figure 

328 from matplotlib import pyplot as plt # noqa 

329 

330 if isinstance(cluster_label, int): 

331 lbls2plot = [cluster_label] 

332 elif isinstance(cluster_label, list): 

333 lbls2plot = cluster_label 

334 elif cluster_label == 'all': 

335 lbls2plot = list(range(self.n_clusters)) 

336 else: 

337 raise ValueError(cluster_label) 

338 

339 # create a single plot 

340 if len(lbls2plot) == 1: 

341 if cluster_label == 'all': 

342 cluster_label = 0 

343 

344 fig, axes = plt.figure(), None 

345 for i in range(100): 

346 plt.plot(self.src_wavelengths, self.MLdict[cluster_label].cluster_sample_spectra[i, :]) 

347 

348 plt.xlabel('wavelength [nm]') 

349 plt.ylabel('%s %s\nreflectance [0-10000]' % (self.src_satellite, self.src_sensor)) 

350 plt.title('Cluster #%s' % cluster_label) 

351 plt.grid(lw=0.2) 

352 plt.ylim(0, 10000) 

353 

354 if include_mean_spectrum: 

355 plt.plot(self.src_wavelengths, self.MLdict[cluster_label].cluster_center, c='black', lw=3) 

356 if include_median_spectrum: 

357 plt.plot(self.src_wavelengths, np.median(self.MLdict[cluster_label].cluster_sample_spectra, axis=0), 

358 '--', c='black', lw=3) 

359 

360 # create a plot with multiple subplots 

361 else: 

362 nplots = len(lbls2plot) 

363 ncols = nplots if nplots < ncols else ncols 

364 nrows = nplots // ncols if not nplots % ncols else nplots // ncols + 1 

365 figsize = (4 * ncols, 3 * nrows) 

366 fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex='all', sharey='all', 

367 **kw_fig) 

368 

369 for lbl, ax in tqdm(zip(lbls2plot, axes.flatten()), total=nplots): 

370 for i in range(100): 

371 ax.plot(self.src_wavelengths, self.MLdict[lbl].cluster_sample_spectra[i, :], lw=1) 

372 

373 if include_mean_spectrum: 

374 ax.plot(self.src_wavelengths, self.MLdict[lbl].cluster_center, c='black', lw=2) 

375 if include_median_spectrum: 

376 ax.plot(self.src_wavelengths, np.median(self.MLdict[lbl].cluster_sample_spectra, axis=0), 

377 '--', c='black', lw=3) 

378 

379 ax.grid(lw=0.2) 

380 ax.set_ylim(0, 10000) 

381 

382 if ax.get_subplotspec().is_last_row(): 

383 ax.set_xlabel('wavelength [nm]') 

384 if ax.get_subplotspec().is_first_col(): 

385 ax.set_ylabel('%s %s\nreflectance [0-10000]' % (self.src_satellite, self.src_sensor)) 

386 ax.set_title('Cluster #%s' % lbl) 

387 

388 plt.tight_layout() 

389 plt.show() 

390 

391 return fig, axes 

392 

393 def _collect_stats(self, cluster_label): 

394 df = DataFrame(columns=['band', 'wavelength', 'RMSE', 'MAE', 'MAPE']) 

395 df.band = self.tgt_LBA 

396 df.wavelength = np.round(self.tgt_wavelengths, 1) 

397 df.RMSE = np.round(self.MLdict[cluster_label].rmse_per_band, 1) 

398 df.MAE = np.round(self.MLdict[cluster_label].mae_per_band, 1) 

399 df.MAPE = np.round(self.MLdict[cluster_label].mape_per_band, 1) 

400 

401 overall_stats = dict(scores=self.MLdict[cluster_label].scores) 

402 

403 return df, overall_stats 

404 

405 def print_stats(self): 

406 from tabulate import tabulate 

407 

408 for lbl in range(self.n_clusters): 

409 print('Cluster #%s:' % lbl) 

410 band_stats, overall_stats = self._collect_stats(lbl) 

411 print(overall_stats) 

412 print(tabulate(band_stats, headers=band_stats.columns)) 

413 print() 

414 

415 def to_jsonable_dict(self): 

416 """Create a dictionary containing a JSONable replicate of the current Cluster_Learner instance.""" 

417 common_meta_keys = ['src_satellite', 'src_sensor', 'tgt_satellite', 'tgt_sensor', 'src_LBA', 'tgt_LBA', 

418 'src_n_bands', 'tgt_n_bands', 'src_wavelengths', 'tgt_wavelengths', 'n_clusters', 

419 'spechomo_version', 'spechomo_versionalias'] 

420 jsonable_dict = dict() 

421 decode_types_dict = dict() 

422 

423 # get jsonable dict for global classifier and add decoding type hints 

424 jsonable_dict['classifier_global'] =\ 

425 classifier_to_jsonable_dict(self.global_clf, skipkeys=common_meta_keys, include_typesdict=True) 

426 decode_types_dict['classifiers_all'] = jsonable_dict['classifier_global']['__decode_types'] 

427 del jsonable_dict['classifier_global']['__decode_types'] 

428 

429 # get jsonable dicts for each classifier of self.MLdict and add corresponding decoding type hints 

430 jsonable_dict['classifiers_optimized'] =\ 

431 {i: classifier_to_jsonable_dict(clf, skipkeys=common_meta_keys) 

432 for i, clf in self.MLdict.items()} 

433 

434 # add common metadata and corresponding decoding type hints 

435 for k in common_meta_keys: 

436 jsonable_dict[k], decode_type = get_jsonable_value(getattr(self, k), return_typesdict=True) 

437 

438 if decode_type: 

439 decode_types_dict[k] = decode_type 

440 

441 jsonable_dict['__decode_types'] = decode_types_dict 

442 

443 return jsonable_dict 

444 

445 def save_to_json(self, filepath): 

446 jsonable_dict = self.to_jsonable_dict() 

447 

448 # Create json and save to file 

449 json_txt = json.dumps(jsonable_dict, sort_keys=True, indent=4) 

450 with open(filepath, 'w') as file: 

451 file.write(json_txt) 

452 

453 

454class ClassifierCollection(object): 

455 def __init__(self, path_dillFile): 

456 with open(path_dillFile, 'rb') as inF: 

457 self.content = dill.load(inF) 

458 

459 def __repr__(self): 

460 """Return the representation of ClassifierCollection. 

461 

462 :return: e.g., "{'1__2__3__4__5__7': {('Landsat-5', 'TM'): {'1__2__3__4__5__7': 

463 LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)}, ..." 

464 """ 

465 return pformat(self.content) 

466 

467 def __getitem__(self, item): 

468 """Get a specific item of the ClassifierCollection.""" 

469 try: 

470 return self.content[item] 

471 except KeyError: 

472 raise(KeyError("The classifier has no key '%s'. Available keys are: %s" 

473 % (item, list(self.content.keys())))) 

474 # def save_to_json(self, filepath): 

475 # a = 1 

476 # pass 

477 

478 

479def get_jsonable_value(in_value, return_typesdict=False): 

480 if isinstance(in_value, np.ndarray): 

481 outval = in_value.tolist() 

482 elif isinstance(in_value, list): 

483 outval = np.array(in_value).tolist() 

484 # json.dumps(outval) 

485 else: 

486 outval = in_value 

487 

488 # FIXME: In case of quadratic regression, there are some attributes that are not directly JSONable in this manner. 

489 

490 # create a dictionary containing the data types needed for JSON decoding 

491 typesdict = dict() 

492 if return_typesdict and not isinstance(in_value, (str, int, float, bool)) and in_value is not None: 

493 typesdict['type'] = type(in_value).__name__ 

494 

495 if isinstance(in_value, np.ndarray): 

496 typesdict['dtype'] = in_value.dtype.name 

497 

498 if isinstance(in_value, list): 

499 typesdict['dtype'] = type(in_value[0]).__name__ 

500 

501 if not len(set(type(vv).__name__ for vv in in_value)) == 1: 

502 raise RuntimeError('Lists containing different data types of list elements cannot be made ' 

503 'jsonable without losses.') 

504 

505 if return_typesdict: 

506 return outval, typesdict 

507 else: 

508 return outval 

509 

510 

511def classifier_to_jsonable_dict(clf, skipkeys: list = None, include_typesdict=False): 

512 from sklearn.linear_model import LinearRegression # avoids static TLS error here 

513 

514 if isinstance(clf, LinearRegression): 

515 jsonable_dict = dict(clftype='LR') 

516 typesdict = dict() 

517 

518 for k, v in clf.__dict__.items(): 

519 if skipkeys and k in skipkeys: 

520 continue 

521 

522 if include_typesdict: 

523 jsonable_dict[k], typesdict[k] = get_jsonable_value(v, return_typesdict=True) 

524 else: 

525 jsonable_dict[k] = get_jsonable_value(v) 

526 

527 # if valtype is np.ndarray: 

528 # jsonable_dict[k] = dict(val=v.tolist(), 

529 # dtype=v.dtype.name) 

530 # elif valtype is list: 

531 # jsonable_dict[k] = dict(val=np.array(v).tolist()) 

532 # else: 

533 # jsonable_dict[k] = dict(val=v) 

534 # 

535 # jsonable_dict[k]['valtype'] = valtype.__name__ 

536 

537 else: # Ridge, Pipeline, RandomForestRegressor: 

538 # TODO 

539 raise NotImplementedError('At the moment, only LR classifiers can be serialized to JSON format.') 

540 

541 if include_typesdict: 

542 jsonable_dict['__decode_types'] = {k: v for k, v in typesdict.items() if v} 

543 

544 return jsonable_dict 

545 

546 

547def classifier_from_json_str(json_str): 

548 """Create a spectral harmonization classifier from a JSON string (JSON de-serialization). 

549 

550 :param json_str: the JSON string to be used for de-serialization 

551 :return: 

552 """ 

553 from sklearn.linear_model import LinearRegression # avoids static TLS error here 

554 

555 in_dict = json.loads(json_str) 

556 

557 if in_dict['clftype']['val'] == 'LR': 

558 clf = LinearRegression() 

559 else: 

560 raise NotImplementedError("Unknown object type '%s'." % in_dict['objecttype']) 

561 

562 for k, v in in_dict.items(): 

563 try: 

564 val2set = getattr(builtins, v['valtype'])(v['val']) 

565 except (AttributeError, KeyError): 

566 if v['valtype'] == 'ndarray': 

567 val2set = np.array(v['val']).astype(np.dtype(v['dtype'])) 

568 else: 

569 raise TypeError("Unexpected object type '%s'." % v['valtype']) 

570 

571 setattr(clf, k, val2set) 

572 

573 return clf