Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
3# spechomo, Spectral homogenization of multispectral satellite data
4#
5# Copyright (C) 2019-2021
6# - Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
7# - Helmholtz Centre Potsdam - GFZ German Research Centre for Geosciences Potsdam,
8# Germany (https://www.gfz-potsdam.de/)
9#
10# This software was developed within the context of the GeoMultiSens project funded
11# by the German Federal Ministry of Education and Research
12# (project grant code: 01 IS 14 010 A-C).
13#
14# Licensed under the Apache License, Version 2.0 (the "License");
15# you may not use this file except in compliance with the License.
16# You may obtain a copy of the License at
17#
18# http://www.apache.org/licenses/LICENSE-2.0
19#
20# Please note the following exception: `spechomo` depends on tqdm, which is
21# distributed under the Mozilla Public Licence (MPL) v2.0 except for the files
22# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore".
23# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE.
24#
25# Unless required by applicable law or agreed to in writing, software
26# distributed under the License is distributed on an "AS IS" BASIS,
27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28# See the License for the specific language governing permissions and
29# limitations under the License.
31import os
32import tempfile
33import zipfile
34from collections import OrderedDict
35from pprint import pformat
36from typing import Union, List, TYPE_CHECKING # noqa F401 # flake8 issue
37import json
38import builtins
40if TYPE_CHECKING:
41 from matplotlib import pyplot as plt # noqa F401 # flake8 issue
43from tqdm import tqdm
44import dill
45import numpy as np
46from pandas import DataFrame
47from geoarray import GeoArray # noqa F401 # flake8 issue
49from .classifier_creation import get_filename_classifier_collection, get_machine_learner
50from .exceptions import ClassifierNotAvailableError
51from .utils import im2spectra, spectra2im
class Cluster_Learner(object):
    """Hold the machine learning classifiers of multiple spectral clusters together with a global classifier.

    The classifiers can be applied to an input sensor image by using the predict method.
    """

    def __init__(self, dict_clust_MLinstances, global_classifier):
        # type: (Union[dict, ClassifierCollection], any) -> None
        """Get an instance of Cluster_Learner.

        Common metadata (satellite/sensor names, band assignments, wavelengths, ...) is copied from the
        classifier with the lowest cluster label.

        :param dict_clust_MLinstances:  a dictionary of cluster specific machine learning classifiers
                                        (keys: cluster labels, values: classifiers)
        :param global_classifier:       the global machine learning classifier to be applied at image positions with
                                        high spectral dissimilarity to the available cluster centers
        """
        self.cluster_pixVals = sorted(dict_clust_MLinstances.keys())  # type: List[int]
        self.MLdict = OrderedDict((clust, dict_clust_MLinstances[clust]) for clust in self.cluster_pixVals)
        self.global_clf = global_classifier

        # the first classifier serves as the source of the metadata shared by all classifiers
        sample_MLinst = list(self.MLdict.values())[0]
        self.src_satellite = sample_MLinst.src_satellite
        self.src_sensor = sample_MLinst.src_sensor
        self.tgt_satellite = sample_MLinst.tgt_satellite
        self.tgt_sensor = sample_MLinst.tgt_sensor
        self.src_LBA = sample_MLinst.src_LBA
        self.tgt_LBA = sample_MLinst.tgt_LBA
        self.src_n_bands = sample_MLinst.src_n_bands
        self.tgt_n_bands = sample_MLinst.tgt_n_bands
        self.src_wavelengths = sample_MLinst.src_wavelengths
        self.tgt_wavelengths = sample_MLinst.tgt_wavelengths
        self.n_clusters = sample_MLinst.n_clusters
        self.cluster_centers = np.array([cc.cluster_center for cc in self.MLdict.values()])

        # classifiers may lack the version attributes -> fall back to 'NA'
        self.spechomo_version = getattr(sample_MLinst, 'spechomo_version', 'NA')
        self.spechomo_versionalias = getattr(sample_MLinst, 'spechomo_versionalias', 'NA')

    @classmethod
    def from_disk(cls, classifier_rootDir, method, n_clusters,
                  src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA, n_estimators=50):
        # type: (str, str, int, str, str, list, str, str, list, int) -> Cluster_Learner
        """Read a previously saved ClusterLearner from disk and return a ClusterLearner instance.

        Describe the classifier specifications with the given arguments. If '<method>_classifiers.zip' exists
        within classifier_rootDir, the classifiers are read from this archive. Otherwise, classifier_rootDir is
        expected to be a directory that directly contains the classifier files.

        :param classifier_rootDir:  root directory of the classifiers
        :param method:              harmonization method
                                    'LR':  Linear Regression
                                    'RR':  Ridge Regression
                                    'QR':  Quadratic Regression
                                    'RFR': Random Forest Regression (50 trees; does not allow spectral sub-clustering)
        :param n_clusters:          number of clusters
        :param src_satellite:       source satellite, e.g., 'Landsat-8'
        :param src_sensor:          source sensor, e.g., 'OLI_TIRS'
        :param src_LBA:             source LayerBandsAssignment
        :param tgt_satellite:       target satellite, e.g., 'Landsat-8'
        :param tgt_sensor:          target sensor, e.g., 'OLI_TIRS'
        :param tgt_LBA:             target LayerBandsAssignment
        :param n_estimators:        number of estimators (only used in case of method=='RFR')
        :return:                    classifier instance loaded from disk
        :raises FileNotFoundError:  if neither the zip archive nor the root directory exists
        """
        args = (method, src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA)
        kw_clfinit = dict(n_estimators=n_estimators)

        # choose the reader depending on whether the classifiers are zipped or not
        if os.path.isfile(os.path.join(classifier_rootDir, '%s_classifiers.zip' % method)):
            reader = cls._read_ClassifierCollection_from_zipFile
        elif os.path.isdir(classifier_rootDir):
            reader = cls._read_ClassifierCollection_from_unzippedFile
        else:
            raise FileNotFoundError("No '%s' classifiers available at %s." % (method, classifier_rootDir))

        # get cluster specific classifiers and store them in a ClassifierCollection dictionary
        dict_clust_MLinstances = reader(classifier_rootDir, *args, n_clusters=n_clusters, **kw_clfinit)

        # the global classifier is the single classifier (label 0) of the 1-cluster collection
        global_clf = reader(classifier_rootDir, *args, n_clusters=1, **kw_clfinit)[0]

        # use cls instead of the hard-coded class name so that subclasses get instances of their own type
        return cls(dict_clust_MLinstances, global_clf)

    @staticmethod
    def _read_ClassifierCollection_from_zipFile(classifier_rootDir, method, src_satellite, src_sensor,
                                                src_LBA, tgt_satellite, tgt_sensor, tgt_LBA, n_clusters,
                                                **kw_clfinit):
        # type: (str, str, str, str, list, str, str, list, int, dict) -> ClassifierCollection
        """Extract the requested classifier file from '<method>_classifiers.zip' and read it.

        :raises FileNotFoundError:  if the archive does not contain the requested classifier file
        """
        path_classifier_zip = os.path.join(classifier_rootDir, '%s_classifiers.zip' % method)

        # read requested classifier from zip archive and create a ClassifierCollection
        with zipfile.ZipFile(path_classifier_zip, "r") as zf, tempfile.TemporaryDirectory() as td:
            fName_clf = get_filename_classifier_collection(method, src_satellite, src_sensor, n_clusters=n_clusters,
                                                           **kw_clfinit)
            try:
                zf.extract(fName_clf, td)
            except KeyError:
                raise FileNotFoundError("No classifiers for %s %s with %d clusters contained in %s."
                                        % (src_satellite, src_sensor, n_clusters, path_classifier_zip))

            # kw_clfinit must be forwarded: the reader below rebuilds the file name with the same helper and
            # has to arrive at the name of the file that was just extracted
            return Cluster_Learner._read_ClassifierCollection_from_unzippedFile(
                td, method, src_satellite, src_sensor, src_LBA,
                tgt_satellite, tgt_sensor, tgt_LBA, n_clusters, **kw_clfinit)

    @staticmethod
    def _read_ClassifierCollection_from_unzippedFile(classifier_rootDir, method, src_satellite, src_sensor, src_LBA,
                                                     tgt_satellite, tgt_sensor, tgt_LBA, n_clusters, **kw_clfinit):
        # type: (str, str, str, str, list, str, str, list, int, dict) -> ClassifierCollection
        """Read the requested classifiers from a classifier file located directly in classifier_rootDir.

        :raises ClassifierNotAvailableError:  if the file has no entry for the requested sensor/band combination
        :raises RuntimeError:                 if the number of classifiers read differs from n_clusters
        :raises ValueError:                   if a classifier is not of the type expected for the given method
        """
        fName_clf_clustN = get_filename_classifier_collection(method, src_satellite, src_sensor,
                                                              n_clusters=n_clusters, **kw_clfinit)
        path_classifier = os.path.join(classifier_rootDir, fName_clf_clustN)

        # read the requested classifiers from the file; they are keyed by the joined source bands,
        # the (target satellite, target sensor) tuple and the joined target bands
        try:
            clf_collection = \
                ClassifierCollection(path_classifier)['__'.join(src_LBA)][tgt_satellite, tgt_sensor]['__'.join(tgt_LBA)]
        except KeyError:
            raise ClassifierNotAvailableError(method, src_satellite, src_sensor, src_LBA,
                                              tgt_satellite, tgt_sensor, tgt_LBA, n_clusters)

        # validation
        expected_MLtype = type(get_machine_learner(method))
        if len(clf_collection.keys()) != n_clusters:
            raise RuntimeError('Read classifier with %s clusters instead of %s.'
                               % (len(clf_collection.keys()), n_clusters))
        for label, ml in clf_collection.items():
            if not isinstance(ml, expected_MLtype):
                raise ValueError("The given dillFile %s contains a spectral cluster (label '%s') with a %s machine "
                                 "learner instead of the expected %s."
                                 % (os.path.basename(path_classifier), label, type(ml), expected_MLtype.__name__,))

        return clf_collection

    def __iter__(self):
        """Iterate over the cluster specific classifiers in ascending order of their cluster labels."""
        for cluster in self.cluster_pixVals:
            yield self.MLdict[cluster]

    def predict(self, im_src, cmap, nodataVal=None, cmap_nodataVal=None, cmap_unclassifiedVal=-1):
        # type: (Union[np.ndarray, GeoArray], np.ndarray, Union[int, float], Union[int, float], Union[int, float]) -> np.ndarray # noqa
        """Predict target satellite spectral information using separate prediction coefficients for spectral clusters.

        :param im_src:                  input image to be used for prediction
        :param cmap:                    classification map that assigns each image spectrum to a corresponding cluster
                                        -> must be a 2D np.ndarray with the same X-/Y-dimension like im_src
        :param nodataVal:               nodata value to be used to fill into the predicted image
                                        (0 is used if not given)
        :param cmap_nodataVal:          nodata class value of the nodata class of the classification map
        :param cmap_unclassifiedVal:    'unclassified' class value of the nodata class of the classification map
                                        (the global classifier is used at these positions)
        :return:                        predicted image (float32 if cmap contains more than one label or only
                                        nodata; otherwise the array returned by spectra2im)
        """
        cluster_labels = sorted(np.unique(cmap))

        # initialize the output with nodata so that nodata positions never need to be touched again
        im_pred = np.full((im_src.shape[0], im_src.shape[1], self.tgt_n_bands),
                          fill_value=nodataVal if nodataVal is not None else 0,
                          dtype=np.float32)

        if len(cluster_labels) > 1:
            # iterate over all cluster labels and apply corresponding machine learner parameters
            # to predict target spectra
            for pixVal in cluster_labels:
                if pixVal == cmap_nodataVal:
                    # at nodata positions, the predicted pixel value will also be nodata
                    continue

                if pixVal == cmap_unclassifiedVal:
                    # apply global homogenization coefficients
                    classifier = self.global_clf
                else:
                    # apply cluster specific homogenization coefficients
                    classifier = self.MLdict[pixVal]

                mask_pixVal = cmap == pixVal
                im_pred[mask_pixVal] = classifier.predict(im_src[mask_pixVal])

        else:
            # predict target spectra directly (much faster than the above algorithm)
            pixVal = cluster_labels[0]

            # if im_src consists only of nodata values, im_pred simply stays at nodataVal
            if pixVal != cmap_nodataVal:
                if pixVal == cmap_unclassifiedVal:
                    # apply global homogenization coefficients
                    classifier = self.global_clf
                else:
                    classifier = self.MLdict[pixVal]
                    assert classifier.clusterlabel == pixVal

                spectra = im2spectra(im_src)
                spectra_pred = classifier.predict(spectra)
                im_pred = spectra2im(spectra_pred, im_src.shape[0], im_src.shape[1])

        return im_pred

    def predict_weighted_averages(self, im_src, cmap_3D, weights_3D=None, nodataVal=None,
                                  cmap_nodataVal=None, cmap_unclassifiedVal=-1):
        # type: (Union[np.ndarray, GeoArray], np.ndarray, np.ndarray, Union[int, float], Union[int, float], Union[int, float]) -> np.ndarray # noqa
        """Predict target satellite spectral information using separate prediction coefficients for spectral clusters.

        NOTE:   This version of the prediction function uses the prediction coefficients of multiple spectral clusters
                and computes the result as weighted average of them. Therefore, the classification map must assign
                multiple spectral clusters to each input pixel.

        NOTE:   Setting the weights of unclassified pixels (cmap_3D[y,x,z>0] == cmap_unclassifiedVal) to zero is
                currently NOT applied (the corresponding code is disabled, see the comments in the method body).
                The prediction based on the global coefficients therefore enters the average with the given weight.

        :param im_src:                  input image to be used for prediction
        :param cmap_3D:                 classification map that assigns each image spectrum to multiple corresponding
                                        clusters
                                        -> must be a 3D np.ndarray with the same X-/Y-dimension like im_src
        :param weights_3D:              weights of the classification map bands (same shape as cmap_3D);
                                        if None, all classification map bands are weighted equally
        :param nodataVal:               nodata value to be used to fill into the predicted image
        :param cmap_nodataVal:          nodata class value of the nodata class of the classification map
        :param cmap_unclassifiedVal:    'unclassified' class value of the nodata class of the classification map
        :return:                        the predicted image as returned by spectra2im
        :raises ValueError:             if cmap_3D is not 3D or if its shape differs from the shape of weights_3D
        """
        if not cmap_3D.ndim > 2:
            raise ValueError('Input classification map needs at least 2 bands to compute prediction results as '
                             'weighted averages.')

        # weights_3D is optional -> only validate its shape if it was actually provided
        if weights_3D is not None and cmap_3D.shape != weights_3D.shape:
            raise ValueError("The input arrays 'cmap_3D' and 'weights_3D' need to have the same dimensions. "
                             "Received %s vs. %s." % (cmap_3D.shape, weights_3D.shape))

        # predict for each classification map band
        ims_pred_temp = [
            self.predict(im_src,
                         cmap_3D[:, :, band],
                         nodataVal=nodataVal,
                         cmap_nodataVal=cmap_nodataVal,
                         cmap_unclassifiedVal=cmap_unclassifiedVal)
            for band in range(cmap_3D.shape[2])
        ]

        # merge classification results by weighted averaging
        # (dimensions are taken from cmap_3D because weights_3D may be None)
        nsamp = cmap_3D.shape[0] * cmap_3D.shape[1]
        nbandpred = ims_pred_temp[0].shape[2]
        nbandscmap = cmap_3D.shape[2]

        if weights_3D is None:
            weights = np.ones((nsamp, nbandpred, nbandscmap))
        else:
            # repeat the per-pixel weights for each predicted band
            weights = np.tile(weights_3D.reshape((nsamp, 1, nbandscmap)),
                              (1, nbandpred, 1))  # n_samples x n_tgt_bands x n_cmap_bands

        # set weighting of unclassified pixel positions to zero (except from the first cmap band)
        # -> currently disabled:
        # mask_unclassif = np.tile(cmap_3D.reshape(nsamp, 1, nbandscmap), (1, nbandpred, 1)) == cmap_unclassifiedVal
        # mask_unclassif[:, :, :1] = False  # if all other clusters are invalid, at least the first one is used for prediction  # noqa
        # weights[mask_unclassif] = 0
        # FIXME this computes the prediction for all k-neighbors, no matter if the weights are 0
        spectra_pred = np.average(np.dstack([im2spectra(im) for im in ims_pred_temp]),
                                  weights=weights,
                                  axis=2)
        im_pred = spectra2im(spectra_pred,
                             tgt_rows=im_src.shape[0],
                             tgt_cols=im_src.shape[1])

        return im_pred

    def plot_sample_spectra(self, cluster_label='all', include_mean_spectrum=True, include_median_spectrum=True,
                            ncols=5, **kw_fig):
        # type: (Union[str, int, List], bool, bool, int, dict) -> plt.figure
        """Plot up to 100 sample spectra of the given spectral cluster(s).

        :param cluster_label:           a single cluster label, a list of cluster labels or 'all'
        :param include_mean_spectrum:   whether to add the cluster center as a solid black line
        :param include_median_spectrum: whether to add the median of the sample spectra as a dashed black line
        :param ncols:                   maximum number of subplot columns (only used if multiple clusters are plotted)
        :param kw_fig:                  keyword arguments passed to plt.subplots()
                                        (only used if multiple clusters are plotted)
        :return:                        tuple of the figure and the subplot axes (None in case of a single plot)
        :raises ValueError:             if cluster_label is neither an integer, a list nor 'all'
        """
        from matplotlib import pyplot as plt  # noqa

        if isinstance(cluster_label, int):
            lbls2plot = [cluster_label]
        elif isinstance(cluster_label, list):
            lbls2plot = cluster_label
        elif cluster_label == 'all':
            lbls2plot = list(range(self.n_clusters))
        else:
            raise ValueError(cluster_label)

        max_spectra = 100  # upper limit of sample spectra drawn per cluster

        # create a single plot
        if len(lbls2plot) == 1:
            # take the label from lbls2plot so that one-element lists and 'all' (with one cluster) work as well
            lbl = lbls2plot[0]
            clf = self.MLdict[lbl]

            fig, axes = plt.figure(), None
            # slicing avoids an IndexError if fewer than max_spectra sample spectra are available
            for spectrum in clf.cluster_sample_spectra[:max_spectra]:
                plt.plot(self.src_wavelengths, spectrum)

            plt.xlabel('wavelength [nm]')
            plt.ylabel('%s %s\nreflectance [0-10000]' % (self.src_satellite, self.src_sensor))
            plt.title('Cluster #%s' % lbl)
            plt.grid(lw=0.2)
            plt.ylim(0, 10000)

            if include_mean_spectrum:
                plt.plot(self.src_wavelengths, clf.cluster_center, c='black', lw=3)
            if include_median_spectrum:
                plt.plot(self.src_wavelengths, np.median(clf.cluster_sample_spectra, axis=0),
                         '--', c='black', lw=3)

        # create a plot with multiple subplots
        else:
            nplots = len(lbls2plot)
            ncols = nplots if nplots < ncols else ncols
            nrows = nplots // ncols if not nplots % ncols else nplots // ncols + 1
            figsize = (4 * ncols, 3 * nrows)
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex='all', sharey='all',
                                     **kw_fig)

            for lbl, ax in tqdm(zip(lbls2plot, axes.flatten()), total=nplots):
                clf = self.MLdict[lbl]

                for spectrum in clf.cluster_sample_spectra[:max_spectra]:
                    ax.plot(self.src_wavelengths, spectrum, lw=1)

                if include_mean_spectrum:
                    ax.plot(self.src_wavelengths, clf.cluster_center, c='black', lw=2)
                if include_median_spectrum:
                    ax.plot(self.src_wavelengths, np.median(clf.cluster_sample_spectra, axis=0),
                            '--', c='black', lw=3)

                ax.grid(lw=0.2)
                ax.set_ylim(0, 10000)

                # axes are shared -> only label the outer subplots
                if ax.get_subplotspec().is_last_row():
                    ax.set_xlabel('wavelength [nm]')
                if ax.get_subplotspec().is_first_col():
                    ax.set_ylabel('%s %s\nreflectance [0-10000]' % (self.src_satellite, self.src_sensor))
                ax.set_title('Cluster #%s' % lbl)

        plt.tight_layout()
        plt.show()

        return fig, axes

    def _collect_stats(self, cluster_label):
        """Collect the accuracy statistics of the classifier belonging to the given cluster label.

        :param cluster_label:   label of the cluster to get the statistics for
        :return:                tuple of a DataFrame with the per-band statistics (rounded to one decimal place)
                                and a dictionary containing the scores of the classifier
        """
        clf = self.MLdict[cluster_label]

        df = DataFrame(columns=['band', 'wavelength', 'RMSE', 'MAE', 'MAPE'])
        df.band = self.tgt_LBA
        df.wavelength = np.round(self.tgt_wavelengths, 1)
        df.RMSE = np.round(clf.rmse_per_band, 1)
        df.MAE = np.round(clf.mae_per_band, 1)
        df.MAPE = np.round(clf.mape_per_band, 1)

        overall_stats = dict(scores=clf.scores)

        return df, overall_stats

    def print_stats(self):
        """Print the overall and the per-band statistics of the cluster labels 0 to n_clusters - 1."""
        from tabulate import tabulate

        for lbl in range(self.n_clusters):
            print('Cluster #%s:' % lbl)
            band_stats, overall_stats = self._collect_stats(lbl)
            print(overall_stats)
            print(tabulate(band_stats, headers=band_stats.columns))
            print()

    def to_jsonable_dict(self):
        """Create a dictionary containing a JSONable replicate of the current Cluster_Learner instance.

        The metadata shared by all classifiers is stored only once at the top level of the dictionary. The key
        '__decode_types' holds the type hints needed to decode the values again.
        """
        common_meta_keys = ['src_satellite', 'src_sensor', 'tgt_satellite', 'tgt_sensor', 'src_LBA', 'tgt_LBA',
                            'src_n_bands', 'tgt_n_bands', 'src_wavelengths', 'tgt_wavelengths', 'n_clusters',
                            'spechomo_version', 'spechomo_versionalias']
        jsonable_dict = dict()
        decode_types_dict = dict()

        # get jsonable dict for global classifier; its decoding type hints are moved to the
        # common type hints dictionary under the key 'classifiers_all'
        global_dict = classifier_to_jsonable_dict(self.global_clf, skipkeys=common_meta_keys, include_typesdict=True)
        decode_types_dict['classifiers_all'] = global_dict.pop('__decode_types')
        jsonable_dict['classifier_global'] = global_dict

        # get jsonable dicts for each classifier of self.MLdict
        jsonable_dict['classifiers_optimized'] = \
            {i: classifier_to_jsonable_dict(clf, skipkeys=common_meta_keys)
             for i, clf in self.MLdict.items()}

        # add common metadata and corresponding decoding type hints
        for k in common_meta_keys:
            jsonable_dict[k], decode_type = get_jsonable_value(getattr(self, k), return_typesdict=True)

            if decode_type:
                decode_types_dict[k] = decode_type

        jsonable_dict['__decode_types'] = decode_types_dict

        return jsonable_dict

    def save_to_json(self, filepath):
        """Serialize the Cluster_Learner instance to JSON and write it to the given file path.

        :param filepath:    path of the JSON file to be written
        """
        jsonable_dict = self.to_jsonable_dict()

        # Create json and save to file
        json_txt = json.dumps(jsonable_dict, sort_keys=True, indent=4)
        with open(filepath, 'w') as file:
            file.write(json_txt)
class ClassifierCollection(object):
    """Read-only, dictionary-like access to the classifiers stored in a dill file."""

    def __init__(self, path_dillFile):
        """Load the content of the given dill file.

        NOTE(review): Deserializing a dill file can execute arbitrary code. Only open files from trusted sources.

        :param path_dillFile:   path of the dill file containing the classifiers
        """
        with open(path_dillFile, 'rb') as dill_stream:
            loaded = dill.load(dill_stream)

        self.content = loaded

    def __repr__(self):
        """Return the representation of ClassifierCollection.

        :return: the pretty-printed content, e.g.,
                 "{'1__2__3__4__5__7': {('Landsat-5', 'TM'): {'1__2__3__4__5__7':
                 LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)}, ..."
        """
        return pformat(self.content)

    def __getitem__(self, item):
        """Get a specific item of the ClassifierCollection.

        :raises KeyError:   if the item is not available (the message lists the available keys)
        """
        try:
            value = self.content[item]
        except KeyError:
            available_keys = list(self.content.keys())
            raise KeyError("The classifier has no key '%s'. Available keys are: %s"
                           % (item, available_keys))

        return value
def get_jsonable_value(in_value, return_typesdict=False):
    """Convert the given value to a JSON serializable replicate.

    NumPy arrays and lists are converted to (nested) lists of native Python values, NumPy scalars are converted to
    the corresponding native Python scalar. All other values are returned unchanged.

    :param in_value:            the value to be converted
    :param return_typesdict:    whether to additionally return a dictionary with the type hints needed for decoding
    :return:                    the converted value or, if return_typesdict is True, a tuple of the converted value
                                and the type hints dictionary. The dictionary is empty for None, str, int, float
                                and bool. Otherwise, it contains the 'type' name of in_value and, for arrays and
                                non-empty lists, the 'dtype' name of the elements.
    :raises RuntimeError:       if return_typesdict is True and in_value is a list with elements of different types
    """
    if isinstance(in_value, np.ndarray):
        outval = in_value.tolist()
    elif isinstance(in_value, list):
        # the round trip via NumPy also converts contained NumPy scalars to native Python values
        outval = np.array(in_value).tolist()
    elif isinstance(in_value, np.generic):
        # NumPy scalars such as np.int64 or np.float32 are not JSON serializable
        outval = in_value.item()
    else:
        outval = in_value

    # FIXME: In case of quadratic regression, there are some attributes that are not directly JSONable in this manner.

    if not return_typesdict:
        return outval

    # create a dictionary containing the data types needed for JSON decoding
    typesdict = dict()
    if in_value is not None and not isinstance(in_value, (str, int, float, bool)):
        typesdict['type'] = type(in_value).__name__

        if isinstance(in_value, np.ndarray):
            typesdict['dtype'] = in_value.dtype.name

        elif isinstance(in_value, list) and in_value:
            # an empty list has no element type -> no 'dtype' hint is added in that case
            element_types = {type(vv).__name__ for vv in in_value}

            if len(element_types) != 1:
                raise RuntimeError('Lists containing different data types of list elements cannot be made '
                                   'jsonable without losses.')

            typesdict['dtype'] = element_types.pop()

    return outval, typesdict
def classifier_to_jsonable_dict(clf, skipkeys: list = None, include_typesdict=False):
    """Create a JSONable dictionary from the attributes of the given classifier.

    :param clf:                 the classifier to be converted (only LinearRegression instances are supported)
    :param skipkeys:            names of classifier attributes to be left out
    :param include_typesdict:   whether to add the key '__decode_types' holding the non-empty decoding type hints
                                of the attributes
    :return:                    dictionary with the key 'clftype' and one key per converted classifier attribute
    :raises NotImplementedError:  if clf is not a LinearRegression instance
    """
    from sklearn.linear_model import LinearRegression  # avoids static TLS error here

    if not isinstance(clf, LinearRegression):  # Ridge, Pipeline, RandomForestRegressor:
        # TODO
        raise NotImplementedError('At the moment, only LR classifiers can be serialized to JSON format.')

    out_dict = dict(clftype='LR')
    decode_types = dict()

    for attr_name, attr_value in clf.__dict__.items():
        if skipkeys and attr_name in skipkeys:
            continue

        if include_typesdict:
            out_dict[attr_name], decode_types[attr_name] = get_jsonable_value(attr_value, return_typesdict=True)
        else:
            out_dict[attr_name] = get_jsonable_value(attr_value)

    if include_typesdict:
        # only keep the type hints of those attributes that actually need one
        out_dict['__decode_types'] = {name: hint for name, hint in decode_types.items() if hint}

    return out_dict
def classifier_from_json_str(json_str):
    """Create a spectral harmonization classifier from a JSON string (JSON de-serialization).

    Each entry of the decoded JSON object is expected to be a dictionary with the keys 'val' (the value) and
    'valtype' (the name of its type). Entries of the type 'ndarray' additionally need the key 'dtype'.
    All entries (including 'clftype') are set as attributes of the returned classifier.

    NOTE(review): classifier_to_jsonable_dict() writes plain values and a separate '__decode_types' entry,
                  which is not the format read here - confirm which format is intended.

    :param json_str:                the JSON string to be used for de-serialization
    :return:                        the de-serialized classifier (LinearRegression instance)
    :raises NotImplementedError:    if the classifier type given by the JSON string is not 'LR'
    :raises TypeError:              if an entry has a value type that is not supported
    """
    in_dict = json.loads(json_str)

    # check the classifier type first so that unsupported types fail fast and with the correct message
    clftype = in_dict['clftype']['val']
    if clftype != 'LR':
        raise NotImplementedError("Unknown object type '%s'." % clftype)

    from sklearn.linear_model import LinearRegression  # avoids static TLS error here

    clf = LinearRegression()

    # Only these constructors may be selected by the JSON content. Resolving arbitrary names from the
    # builtins module would allow a JSON string to call functions like eval() or exec().
    safe_constructors = dict(str=str, int=int, float=float, bool=bool,
                             list=list, tuple=tuple, dict=dict, set=set, frozenset=frozenset)

    for k, v in in_dict.items():
        valtype = v['valtype']

        if valtype in safe_constructors:
            val2set = safe_constructors[valtype](v['val'])
        elif valtype == 'NoneType':
            val2set = None
        elif valtype == 'ndarray':
            val2set = np.array(v['val']).astype(np.dtype(v['dtype']))
        else:
            raise TypeError("Unexpected object type '%s'." % valtype)

        setattr(clf, k, val2set)

    return clf