from ..models.shred import SHRED
from ..processor.data_manager import DataManager
import numpy as np
import torch
import pandas as pd
from typing import Union, Dict
from ..processor.utils import generate_lagged_sensor_measurements
from ..models.latent_forecaster_models.sindy import SINDy_Forecaster
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
class SHREDEngine:
"""
    High-level interface for SHRED model inference and evaluation.

    Parameters
    ----------
    data_manager : DataManager
        Prepared data manager with fitted scalers.
    shred_model : SHRED
        Trained SHRED model instance.

    Attributes
    ----------
    dm : DataManager
        The data manager for preprocessing and postprocessing.
    model : SHRED
        The trained SHRED model.
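
    Examples
    --------
    A minimal sketch, assuming ``dm`` is an already prepared DataManager,
    ``model`` is a trained SHRED instance, and ``test_sensor_measurements``
    is a (T, n_sensors) array:

    >>> engine = SHREDEngine(dm, model)
    >>> latents = engine.sensor_to_latent(test_sensor_measurements)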
"""
def __init__(self, data_manager: DataManager, shred_model: SHRED):
"""
        Initialize the SHRED inference engine.

        Parameters
----------
data_manager : DataManager
Already prepared DataManager with fitted scalers.
shred_model : SHRED
Trained SHRED model instance.
"""
self.dm = data_manager
self.model = shred_model
# ensure model is in eval mode
self.model.eval()
def sensor_to_latent(self, sensor_measurements: Union[np.ndarray, torch.Tensor, pd.DataFrame]) -> np.ndarray:
"""
        Convert raw sensor measurements into latent-space embeddings.

        Parameters
        ----------
        sensor_measurements : array-like of shape (T, n_sensors)
            Raw sensor time series.

        Returns
        -------
        latents : np.ndarray of shape (T, latent_dim)
            The sequence model's (GRU/LSTM) final hidden state at each time index.
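
        Examples
        --------
        A hypothetical call, assuming ``engine`` is a SHREDEngine and
        ``X_test`` is a (T, n_sensors) array of raw sensor readings:

        >>> latents = engine.sensor_to_latent(X_test)  # -> (T, latent_dim) array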
"""
        # 1) Pull out a raw numpy array from whatever container was passed in
        if isinstance(sensor_measurements, pd.DataFrame):
            sensor_measurements = sensor_measurements.values
        elif torch.is_tensor(sensor_measurements):
            sensor_measurements = sensor_measurements.detach().cpu().numpy()
        elif not isinstance(sensor_measurements, np.ndarray):
            raise TypeError(f"Unsupported type {type(sensor_measurements)} for sensor_measurements")
# 2) Scale using the DataManager's fitted scaler (shape -> (T, n_sensors))
scaled_sensor_measurements = self.dm.sensor_scaler.transform(sensor_measurements)
# 3) Build lagged windows (shape -> (T, lags, n_sensors))
lags = self.dm.lags
lagged = generate_lagged_sensor_measurements(scaled_sensor_measurements, lags)
        # 4) Move to a torch tensor on the same device as the model
device = next(self.model.parameters()).device
X = torch.tensor(lagged, dtype=torch.float32, device=device)
        # 5) Run the lagged windows through the sequence model to get latents
        with torch.no_grad():
            # _seq_model_outputs is assumed to return a (T, latent_dim) tensor when sindy=False
            latents = self.model._seq_model_outputs(X, sindy=False)
# 6) Return as numpy
return latents.cpu().numpy()
    def forecast_latent(self, h: int, init_latents: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
"""
        Forecast future latent states using the latent forecaster.

        Parameters
        ----------
        h : int
            Number of future timesteps to forecast.
        init_latents : np.ndarray or torch.Tensor
            Initial latent states for seeding the forecast.

        Returns
        -------
        np.ndarray
            Forecasted latent states.

        Raises
        ------
        RuntimeError
            If no latent forecaster is available.
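
        Examples
        --------
        A hypothetical 50-step roll-out, assuming the SHRED model was built
        with a latent forecaster (e.g. a SINDy_Forecaster) and ``latents``
        came from ``sensor_to_latent``:

        >>> future_latents = engine.forecast_latent(h=50, init_latents=latents)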
"""
        if self.model.latent_forecaster is None:
            raise RuntimeError("No `latent_forecaster` available. Please initialize SHRED "
                               "with a `latent_forecaster` model.")
if isinstance(init_latents, torch.Tensor):
init_latents = init_latents.detach().cpu().numpy()
return self.model.latent_forecaster.forecast(h, init_latents)
    def decode(self, latents: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
"""
        Decode latent states back to the full physical state space.

        Parameters
        ----------
        latents : np.ndarray or torch.Tensor
            Latent representations to decode.

        Returns
        -------
        dict
            Dictionary mapping dataset IDs to reconstructed physical states,
            each of shape (T, *spatial_shape).
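
        Examples
        --------
        A hypothetical round trip, assuming ``latents`` was produced by
        ``sensor_to_latent`` (or ``forecast_latent``):

        >>> fields = engine.decode(latents)
        >>> # each fields[dataset_id] has shape (T, *spatial_shape)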
"""
device = next(self.model.decoder.parameters()).device
if isinstance(latents, np.ndarray):
latents = torch.from_numpy(latents).to(device).float()
else:
latents = latents.to(device).float()
self.model.decoder.eval()
with torch.no_grad():
output = self.model.decoder(latents)
output = output.detach().cpu().numpy()
output = self.dm.data_scaler.inverse_transform(output)
results = {}
start_index = 0
        # Split the flat decoder output back into per-dataset blocks and
        # invert each dataset's compression/scaling transforms
        for dataset_id in self.dm._dataset_ids:
            length = self.dm._dataset_lengths.get(dataset_id)
            Vt = self.dm._Vt_registry.get(dataset_id)
            preSVD_scaler = self.dm._preSVD_scaler_registry.get(dataset_id)
            spatial_shape = self.dm._dataset_spatial_shape.get(dataset_id)
            dataset = output[:, start_index:start_index + length]
            if Vt is not None:
                # project SVD coefficients back to the full spatial dimension
                dataset = dataset @ Vt
            if preSVD_scaler is not None:
                dataset = preSVD_scaler.inverse_transform(dataset)
            original_shape = (dataset.shape[0],) + spatial_shape
            results[dataset_id] = dataset.reshape(original_shape)
            start_index += length
return results
def evaluate(
self,
sensor_measurements: np.ndarray,
        Y: Dict[str, np.ndarray]  # raw full-state, exactly like decode() returns
) -> pd.DataFrame:
"""
        Compute end-to-end reconstruction error in the *physical* space.

        Parameters
        ----------
        sensor_measurements : np.ndarray of shape (T, n_sensors)
            The test sensor time series.
        Y : dict[id] -> np.ndarray of shape (T, *spatial_shape)
            The *raw* full-state ground truth for each dataset id.

        Returns
        -------
        pd.DataFrame
            DataFrame indexed by dataset id with columns [MSE, RMSE, MAE, R2].
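
        Examples
        --------
        A hypothetical evaluation, assuming ``X_test`` holds raw test sensor
        measurements and ``Y_test`` maps each dataset id to its raw
        ground-truth field:

        >>> metrics = engine.evaluate(X_test, Y_test)
        >>> # metrics["RMSE"] gives the per-dataset root-mean-square error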
"""
# 1) Get the model's reconstruction in raw space
latents = self.sensor_to_latent(sensor_measurements)
recon_dict = self.decode(latents) # dict[id] -> (T, *spatial_shape)
# 2) Compute stats
records = []
        for dataset_id, y_true in Y.items():
            y_pred = recon_dict[dataset_id]
            # flatten spatial dimensions so each snapshot is one row for sklearn
            y_true_flat = y_true.reshape(y_true.shape[0], -1)
            y_pred_flat = y_pred.reshape(y_pred.shape[0], -1)
            mse = mean_squared_error(y_true_flat, y_pred_flat)
            rmse = np.sqrt(mse)
            mae = mean_absolute_error(y_true_flat, y_pred_flat)
            r2 = r2_score(y_true_flat, y_pred_flat)
            records.append({
                "dataset": dataset_id, "MSE": mse, "RMSE": rmse,
                "MAE": mae, "R2": r2
            })
return pd.DataFrame.from_records(records).set_index("dataset")