from __future__ import annotations
import logging
import multiprocessing
import re
import time
from datetime import datetime
import click
import numpy as np
from scipy import ndimage
from scipy.io import loadmat
from scipy.linalg import norm
from spectral import envi
import isofit.utils.template_construction as tmpl
from isofit import ray
from isofit.core.common import envi_header, resample_spectrum, svd_inv
from isofit.core.fileio import IO, initialize_output, write_bil_chunk
from isofit.core.geometry import Geometry
from isofit.core.multistate import SurfaceMapping
from isofit.data import env
class Component:
    """Nearest-component lookup over a (possibly combined) surface prior model.

    Parameters
    ----------
    model_dict : dict
        Surface model contents. The keys read here are ``"means"``,
        ``"covs"``, ``"wl"``, ``"normalize"`` and ``"refwl"``;
        ``"surface_categories"`` is optional and defaults to an empty list.

    Raises
    ------
    ValueError
        If ``model_dict["normalize"]`` is not ``"Euclidean"``, ``"RMS"`` or
        ``"None"``.
    """

    def __init__(self, model_dict):
        # The model dict is expected to already combine all surface files
        self.model_dict = model_dict
        # One (mean, covariance) pair per surface component
        self.components = list(zip(self.model_dict["means"], self.model_dict["covs"]))
        self.n_comp = len(self.components)
        # Wavelength grid of the surface model (first row of the "wl" entry)
        self.wl = self.model_dict["wl"][0]
        self.n_wl = len(self.wl)
        # One category label per component, used by pickClosest
        self.surface_categories = self.model_dict.get("surface_categories", [])

        # NOTE: compare against the local value. The previous code tested
        # ``self.normalize`` (never assigned), which raised AttributeError for
        # "None" and for every unrecognized option.
        normalize = self.model_dict["normalize"]
        if normalize == "Euclidean":
            self.norm = lambda r: norm(r)
        elif normalize == "RMS":
            self.norm = lambda r: np.sqrt(np.mean(pow(r, 2)))
        elif normalize == "None":
            self.norm = lambda r: 1.0
        else:
            raise ValueError("Unrecognized Normalization: %s\n" % normalize)

        # Keep only reference wavelengths below 900 or above 2000
        # (presumably nm -- TODO confirm units against the surface model),
        # then map each one to the nearest channel of the model grid.
        refwl = np.squeeze(self.model_dict["refwl"])
        refwl = refwl[(refwl < 900) | (refwl > 2000)]
        idx_ref = [np.argmin(abs(self.wl - w)) for w in np.squeeze(refwl)]
        self.idx_ref = np.array(idx_ref)

        # Per-component statistics restricted to the reference channels
        self.Covs, self.Cinvs, self.mus = [], [], []
        for i in range(self.n_comp):
            Cov = self.components[i][1]
            self.Covs.append(np.array([Cov[j, self.idx_ref] for j in self.idx_ref]))
            self.Cinvs.append(svd_inv(self.Covs[-1]))
            self.mus.append(self.components[i][0][self.idx_ref])

    def pickClosest(self, x, geom):
        """Return the surface index of the component closest to spectrum ``x``.

        Both ``x`` and every component mean are min-max scaled over the
        reference channels; the component with the smallest squared Euclidean
        distance wins. Its (whitespace-stripped) category label is translated
        through ``SurfaceMapping``.

        Parameters
        ----------
        x : np.ndarray
            Spectrum indexed with ``self.idx_ref`` (assumes ``x`` is sampled
            on the same grid as the surface model -- TODO confirm).
        geom : Geometry
            Unused; kept for interface compatibility.

        Returns
        -------
        The ``SurfaceMapping`` entry for the selected category.
        """
        lamb_ref = x[self.idx_ref]
        lamb_ref = (lamb_ref - np.min(lamb_ref)) / (np.max(lamb_ref) - np.min(lamb_ref))

        mds = []
        for ci in range(self.n_comp):
            ref_mu = self.mus[ci]
            ref_mu = (ref_mu - np.min(ref_mu)) / (np.max(ref_mu) - np.min(ref_mu))
            # np.sum keeps the reduction in native code; this runs per pixel
            mds.append(np.sum((lamb_ref - ref_mu) ** 2))

        closest = np.argmin(mds)
        surface_category = self.surface_categories[closest].strip()
        surface_idx = SurfaceMapping[surface_category]
        return surface_idx
@ray.remote(num_cpus=1)
class Worker(object):
    """Ray actor that classifies blocks of image lines.

    Parameters
    ----------
    rdn_file, obs_file, loc_file : str
        Paths of the radiance, observation and location ENVI files; each is
        opened as a BIP-ordered memmap.
    out_file : str
        Path of the (already initialized) ENVI output file.
    model_dict : dict
        Surface model contents handed to ``Component``.
    wl, fwhm : np.ndarray
        Channel centers and widths used to resample the solar irradiance.
    dayofyear : int
        1-based day of year used to look up the earth-sun distance.
    irr_file : str
        Path of the solar irradiance text file.
    loglevel, logfile : str
        Logging configuration applied inside the actor process.
    """

    def __init__(
        self,
        rdn_file: str,
        obs_file: str,
        loc_file: str,
        out_file: str,
        model_dict: dict,
        wl: np.ndarray,
        fwhm: np.ndarray,
        dayofyear: int,
        irr_file: str,
        loglevel: str,
        logfile: str,
    ):
        logging.basicConfig(
            format="%(levelname)s:%(asctime)s ||| %(message)s",
            level=loglevel,
            filename=logfile,
            datefmt="%Y-%m-%d,%H:%M:%S",
        )

        self.rdn = envi.open(envi_header(rdn_file)).open_memmap(interleave="bip")
        self.loc = envi.open(envi_header(loc_file)).open_memmap(interleave="bip")
        self.obs = envi.open(envi_header(obs_file)).open_memmap(interleave="bip")

        self.out_file = out_file
        self.out = envi.open(envi_header(out_file)).open_memmap(
            interleave="bip", writable=True
        )

        self.component = Component(model_dict)

        # solarIrradiance() reads self.wl / self.fwhm / self.esd /
        # self.dayofyear, so all four must be set before it is called.
        # (wl and fwhm were previously never stored -> AttributeError.)
        self.wl = wl
        self.fwhm = fwhm
        self.esd = IO.load_esd()
        self.dayofyear = dayofyear
        self.solar_irr = self.solarIrradiance(irr_file)

    def solarIrradiance(self, irr_file):
        """Load a two-column irradiance file and resample it to ``self.wl``.

        The file is read with ``np.loadtxt`` (``#`` comments) as
        (wavelength, irradiance) columns. The irradiance is divided by 10 and
        by the squared earth-sun distance factor for ``self.dayofyear``.

        Returns
        -------
        The result of ``resample_spectrum`` on the ``self.wl`` /
        ``self.fwhm`` grid.
        """
        irr = np.loadtxt(irr_file, comments="#")
        iwl, irr = irr.T
        # Wavelengths above 100 are taken to be nm and rescaled by 1/1000
        if iwl[0] > 100:
            iwl = iwl / 1000.0
        irr = irr / 10.0  # convert, uW/nm/cm2
        irr_factor = self.esd[self.dayofyear - 1, 1]
        irr = irr / irr_factor**2  # consider solar distance
        return resample_spectrum(irr, iwl, self.wl, self.fwhm)

    def run_lines(self, startstop):
        """Classify every pixel in the half-open line range ``startstop``.

        Parameters
        ----------
        startstop : tuple
            ``(start_line, stop_line)`` indices into the image lines.
        """
        start_line, stop_line = startstop
        # Slice of the writable output memmap covering this chunk
        output = self.out[start_line:stop_line, ...]

        for r in range(start_line, stop_line):
            for c in range(output.shape[1]):
                meas = self.rdn[r, c, :]
                geom = Geometry(
                    obs=self.obs[r, c, :], loc=self.loc[r, c, :], esd=self.esd
                )
                # Radiance scaled by pi over irradiance * cos(solar zenith)
                coszen = np.cos(np.deg2rad(geom.solar_zenith))
                num = meas * np.pi
                denom = self.solar_irr * coszen
                x = num / denom
                output[r - start_line, c, :] = self.component.pickClosest(x, geom)

            unique, counts = np.unique(output[r - start_line, ...], return_counts=True)
            logging.debug(f"Elements: {unique}")
            logging.debug(f"Counts: {counts}")

        # write_bil_chunk expects BIL order, so swap the sample/band axes
        write_bil_chunk(
            np.swapaxes(output, 1, 2),
            self.out_file,
            start_line,
            (self.out.shape[0], self.out.shape[1], 1),
        )
def load_surface_mat(
    surface_files,
    wavelength_file=None,
    keys_to_combine=(
        "means",
        "covs",
        "attribute_means",
        "attribute_covs",
        "surface_categories",
    ),
):
    """Load one or more surface ``.mat`` files into a single model dict.

    Parameters
    ----------
    surface_files : str or dict
        Either a path ending in ``.mat`` (loaded directly), a path ending in
        ``.json`` (resolved through ``tmpl.check_surface_model`` with
        ``multisurface=True``), or a mapping whose values are ``.mat`` paths.
    wavelength_file : str, optional
        Passed to ``tmpl.check_surface_model`` as ``surface_wavelength_path``
        when ``surface_files`` is a ``.json`` path.
    keys_to_combine : iterable of str
        Keys whose arrays are concatenated along axis 0 across files. Every
        other key keeps the value from the first file.

    Returns
    -------
    dict
        The ``loadmat`` result of the first file, with ``keys_to_combine``
        extended by the matching arrays of every following file.

    Raises
    ------
    ValueError
        If a string path has neither a ``.mat`` nor a ``.json`` extension, or
        if no surface file is provided.
    """
    # CLI will pass .json or .mat as string.
    if isinstance(surface_files, str):
        if surface_files.endswith(".mat"):
            surface_files = {"cli_input": surface_files}
        elif surface_files.endswith(".json"):
            logging.debug("Surface config: %s", surface_files)
            logging.debug("Surface wavelength file: %s", wavelength_file)
            surface_files = tmpl.check_surface_model(
                surface_path=surface_files,
                surface_wavelength_path=wavelength_file,
                multisurface=True,
            )
        else:
            raise ValueError(
                f"Unsupported surface file (expected .mat or .json): {surface_files}"
            )

    if not surface_files:
        raise ValueError("No surface files provided.")

    # Apply OE will pass dict
    model_dict = None
    for surface_file in surface_files.values():
        surface_model_dict = loadmat(surface_file)

        # The first file seeds the combined model
        if model_dict is None:
            model_dict = surface_model_dict
            continue

        for key in keys_to_combine:
            assert (
                model_dict[key].ndim == surface_model_dict[key].ndim
            ), "Dimension mismatch between surface component cov matrices"
            model_dict[key] = np.concatenate(
                [model_dict[key], surface_model_dict[key]], axis=0
            )

    return model_dict
def filter_image(out, thresh=100):
    """Remove small connected regions from a classification image.

    Temporary function to clean the image. Memory intensive: every hole is
    processed with full-size boolean masks.

    For each class, connected regions with fewer than ``thresh`` pixels are
    discarded. Each resulting hole is then filled with the most frequent
    class among the surviving pixels bordering it.

    Parameters
    ----------
    out : np.ndarray
        Classification image.
    thresh : int
        Minimum number of pixels a connected region needs to survive.

    Returns
    -------
    np.ndarray
        ``out`` itself when it holds a single class; otherwise a new array of
        the same shape and dtype holding the original class values.
    """
    classes = np.unique(out)

    # Don't do any filtering if array is uniform
    if len(classes) == 1:
        return out

    # Mark every pixel that belongs to a large-enough region of its class.
    # Comparing against the class value directly avoids a sentinel value
    # that could collide with real data.
    kept = np.zeros(out.shape, dtype=bool)
    for value in classes:
        member = out == value
        label, n = ndimage.label(member)
        sizes = np.bincount(label.ravel(), minlength=n + 1)
        # "member &" keeps the background label (0) from ever counting as kept
        kept |= member & (sizes >= thresh)[label]

    # Work in the original value space so class ids are preserved even when
    # they are not the contiguous range 0..k-1.
    result = out.copy()
    holes, n_holes = ndimage.label(~kept)
    for hole_id in range(1, n_holes + 1):
        hole = holes == hole_id
        # Pixels directly bordering the hole; by construction they are kept
        border = ndimage.binary_dilation(hole) & ~hole
        if not border.any():
            # No surviving neighbor (e.g. every region is below thresh):
            # leave the original classification untouched.
            continue
        vals, counts = np.unique(out[border], return_counts=True)
        result[hole] = vals[np.argmax(counts)]

    return result
def multicomponent_classification(
    rdn_file: str,
    obs_file: str,
    loc_file: str,
    out_file: str,
    surface_files: str,
    wavelength_file: str,
    n_cores: int = -1,
    dayofyear: int = None,
    irr_file: str = None,
    clean: bool = False,
    thresh: int = 100,
    ray_address: str = None,
    ray_redis_password: str = None,
    ray_temp_dir=None,
    ray_ip_head=None,
    loglevel: str = "INFO",
    logfile: str = None,
):
    """\
    Classify a radiance file based on a per-pixel prior selection.
    The classification leverages the same methodology ISOFIT uses
    to select a prior distribution from an input .json or .mat file.

    \b
    Parameters
    ----------
    rdn_file: str
        Radiance data cube. Expected to be ENVI format
    obs_file: str
        Observation data cube of shape:
            (path length, to-sensor azimuth, to-sensor zenith,
            to-sun azimuth, to-sun zenith, phase,
            slope, aspect, cosine i, UTC time)
        Expected to be ENVI format
    loc_file: str
        Location data cube of shape (Lon, Lat, Elevation). Expected to be ENVI format
    out_file: str
        Output path to location where to save output file.
    surface_files: str or dict
        CLI entry into the classifier uses a .mat or a .json file
        Apply OE entry into the classifier uses a dict argument.
    wavelength_file: str
        Standard ISOFIT wavelength file (columns: index, wavelength, fwhm)
    n_cores : int, default=-1
        Number of cores to run classifier with. -1 uses every available CPU.
    dayofyear: int
        Day of year for earth-sun distance calculation. When omitted it is
        parsed from a YYYYMMDDtHHMMSS timestamp in the rdn_file path.
    irr_file: str
        Path to irradiance file to use in the classification. Defaults to the
        prism_optimized_irr.dat file of the 20151026_SantaMonica example.
    clean: bool
        Experimental method to filter out noisy classification masks.
        Creates connected binary components and filters out small features.
    thresh: int
        Threshold size to filter out features smaller than this number of pixels.
    ray_address: str
        Passed to ray.init as "address".
    ray_redis_password: str
        Passed to ray.init as "_redis_password".
    ray_temp_dir: str
        Passed to ray.init as "_temp_dir".
    ray_ip_head: str
        Accepted for interface compatibility; currently unused.
    loglevel: str
        Logging level to use (e.g. DEBUG, INFO, etc.)
    logfile: str
        Output location for logging file if writing to disk.
    """
    logging.basicConfig(
        format="%(levelname)s:%(asctime)s ||| %(message)s",
        level=loglevel,
        filename=logfile,
        datefmt="%Y-%m-%d,%H:%M:%S",
    )

    # Get day of year from rdn string
    if not dayofyear:
        match = re.search("([0-9]{8}t[0-9]{6})", rdn_file)
        if match:
            dt = datetime.strptime(match.group(), "%Y%m%dt%H%M%S")
            dayofyear = dt.timetuple().tm_yday
        else:
            logging.error("Could not find day of year from path")
            raise ValueError("Could not find day of year from path")

    if n_cores == -1:
        n_cores = multiprocessing.cpu_count()

    # Get wavelength
    # Assumes a structure where
    # column 0: idx
    # column 1: wl
    # column 2: fwhm
    wl = np.loadtxt(wavelength_file)
    fwhm = wl[:, 2]
    wl = wl[:, 1]

    # Check units of wavelength
    if wl[0] > 100:
        logging.info("Wavelength units of nm inferred...converting to microns")
        wl = wl / 1000.0
        fwhm = fwhm / 1000.0

    # Check to see if irradiance file was passed
    irr_path = [
        "examples",
        "20151026_SantaMonica",
        "data",
        "prism_optimized_irr.dat",
    ]
    irr_file = irr_file if irr_file else str(env.path(*irr_path))

    # The "mapping" is how the program moves between a int-classification
    # And the surface model
    model_dict = load_surface_mat(surface_files, wavelength_file)
    surface_types = model_dict.get("surface_categories", [])
    if not len(surface_types):
        raise ValueError("No surface categories key in provided surface prior file.")

    # Construct the output File
    rdn_ds = envi.open(envi_header(rdn_file))
    rdns = rdn_ds.shape
    rdn_meta = rdn_ds.metadata
    del rdn_ds

    initialize_output(
        {
            "data type": 4,
            "file type": "ENVI Standard",
            "byte order": 0,
        },
        out_file,
        (rdns[0], 1, rdns[1]),
        lines=rdn_meta["lines"],
        samples=rdn_meta["samples"],
        interleave="bil",
        bands="1",
        band_names=["Classification"],
        description=("Per-pixel multicomponent classification"),
    )

    # Ray initialization
    ray_dict = {
        "ignore_reinit_error": True,
        "local_mode": n_cores == 1,
        "address": ray_address,
        "include_dashboard": False,
        "_temp_dir": ray_temp_dir,
        "_redis_password": ray_redis_password,
    }
    ray.init(**ray_dict)

    # One chunk of lines per worker. linspace needs n_workers + 1 break
    # points to yield n_workers chunks; using n_workers points (as before)
    # produced one chunk too few and left an actor idle. Never start more
    # workers than there are image lines.
    n_workers = max(1, min(n_cores, rdns[0]))
    breaks = np.linspace(0, rdns[0], n_workers + 1, dtype=int)
    line_breaks = list(zip(breaks[:-1], breaks[1:]))

    # Share the constructor arguments through the object store once
    wargs = [
        ray.put(arg)
        for arg in (
            rdn_file,
            obs_file,
            loc_file,
            out_file,
            model_dict,
            wl,
            fwhm,
            dayofyear,
            irr_file,
            loglevel,
            logfile,
        )
    ]
    workers = ray.util.ActorPool([Worker.remote(*wargs) for _ in range(n_workers)])

    start_time = time.time()
    # list() drains the iterator so every chunk finishes before timing stops
    list(workers.map_unordered(lambda a, b: a.run_lines.remote(b), line_breaks))
    total_time = time.time() - start_time

    logging.info(
        f"Multicomponent classification complete. {round(total_time,2)}s total, "
        f"{round(rdns[0]*rdns[1]/total_time,4)} spectra/s, "
        f"{round(rdns[0]*rdns[1]/total_time/n_workers,4)} spectra/s/core"
    )

    if clean:
        logging.info(f"Filtering classification image. Using thresh: {thresh}")
        out = envi.open(envi_header(out_file)).open_memmap(
            interleave="bip", writable=True
        )
        out[...] = filter_image(out.copy(), thresh)
        # Push the filtered values to disk before the memmap is released
        out.flush()
        del out
@click.command(
    name="multicomponent_classification",
    help=multicomponent_classification.__doc__,
    no_args_is_help=True,
)
@click.argument("rdn_file")
@click.argument("obs_file")
@click.argument("loc_file")
@click.argument("out_file")
@click.argument("surface_files")
@click.argument("wavelength_file")
@click.option("--n_cores", default=-1, help="Number of cores; -1 uses all CPUs.")
@click.option(
    "--dayofyear",
    type=int,
    help="Day of year; parsed from the radiance file name when omitted.",
)
@click.option("--irr_file", help="Path to a solar irradiance file.")
@click.option("--clean", is_flag=True, help="Filter out small classified regions.")
@click.option("--thresh", default=100, help="Minimum region size (pixels) for --clean.")
@click.option("--ray_address")
@click.option("--ray_redis_password")
@click.option("--ray_temp_dir")
@click.option("--ray_ip_head")
@click.option("--loglevel", default="INFO", help="Logging level (e.g. DEBUG, INFO).")
@click.option("--logfile", help="Write log output to this file.")
def cli(**kwargs):
    """Command-line wrapper around ``multicomponent_classification``.

    Every argument and option is forwarded as a keyword argument of the same
    name. ``--dayofyear`` exposes the function's ``dayofyear`` parameter so
    inputs whose file name carries no timestamp can still be processed.
    """
    multicomponent_classification(**kwargs)
    click.echo("Done")