Downstream Tasks#
Once you have trained a SHRED model, you can use it for various downstream sensing and prediction tasks. The SHREDEngine object provides a convenient interface that wraps your trained SHRED model and DataManager, making it easy to:

- encode sensor measurements to latent space
- forecast future latent space states
- decode latent space back to full-state space
- evaluate full-state reconstructions from sensor measurements against the ground truth
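The sketch below ties these four steps together end to end. It is a minimal example, assuming you already have a fitted DataManager (data_manager), a trained SHRED model with a latent_forecaster attached (trained_shred_model), raw test sensor measurements (test_sensor_data), and a ground-truth dictionary (ground_truth); each step is covered in detail in the sections that follow.

from pyshred import SHREDEngine

# Wrap the trained SHRED model and its DataManager
engine = SHREDEngine(data_manager, trained_shred_model)

# 1. Encode raw sensor measurements into the latent space
latents = engine.sensor_to_latent(test_sensor_data)

# 2. Forecast the latent space forward in time from the most recent latent states
seed_length = trained_shred_model.latent_forecaster.seed_length
future_latents = engine.forecast_latent(h=50, init_latents=latents[-seed_length:])

# 3. Decode latent states back to the full-state space
reconstruction = engine.decode(latents)
forecast = engine.decode(future_latents)

# 4. Evaluate reconstructions from sensor measurements against the ground truth
metrics = engine.evaluate(sensor_measurements=test_sensor_data, Y=ground_truth)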
Initialize SHREDEngine#
To initialize SHREDEngine, pass in a DataManager and a trained SHRED model.
from pyshred import SHREDEngine
# Assuming you have a trained SHRED model and DataManager
engine = SHREDEngine(data_manager, trained_shred_model)
Encode sensor measurements to latent space#
The SHREDEngine object provides a .sensor_to_latent() method for generating the latent space associated with raw sensor measurements. The raw sensor measurements should have shape (time_steps, num_sensors), and the column order of the sensors should match the .sensor_measurements_df attribute of the DataManager.
# Raw sensor measurements (e.g., from physical sensors)
sensor_data = your_sensor_measurements # Shape: (time_steps, num_sensors)
# Convert to latent space
latent_representation = engine.sensor_to_latent(sensor_data)

Figure: The sensor_to_latent method takes in raw sensor measurements, scales them, generates lagged sequences of the scaled sensor measurements, and passes them through SHRED’s sequence_model to obtain the latent space associated with the raw sensor measurements.#
Note: the lagged sequences near the start are padded with zeros because there are not enough earlier sensor measurements to look back on.
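To make the zero padding concrete, the snippet below is a simplified illustration of how lagged sequences could be formed from scaled sensor measurements. It is not the library’s internal implementation, and the exact windowing convention may differ; it only mirrors the behavior described above.

import numpy as np

def build_lagged_sequences(scaled_sensors, lags):
    """Illustrative only: stack each timestep with its preceding scaled
    measurements into a (lags, num_sensors) window, zero-padding where
    there is not enough history to look back on."""
    T, num_sensors = scaled_sensors.shape
    sequences = np.zeros((T, lags, num_sensors))
    for t in range(T):
        window = scaled_sensors[max(0, t - lags + 1): t + 1]
        sequences[t, lags - len(window):] = window  # earlier rows stay zero
    return sequences

# sequences[t] is the lagged window ending at timestep t; the first few
# windows are partially zero, matching the note above.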
Forecast future latent space states#
The SHREDEngine provides a .forecast_latent() method that predicts future latent states starting from an initial sequence of latent vectors. This method is available if SHRED’s .latent_forecaster is not None.
# Generate latent states from current sensor measurements
current_latents = engine.sensor_to_latent(sensor_data)
# Set forecast horizon (number of steps to predict into the future)
forecast_horizon = 50
# Prepare initial seed for forecasting
seed_length = trained_shred_model.latent_forecaster.seed_length
init_latents = current_latents[-seed_length:]
# Forecast future latent states
forecasted_latents = engine.forecast_latent(h=forecast_horizon, init_latents=init_latents)

Figure: The forecast_latent method takes in a forecast horizon (number of steps to forecast into the future) and an initial latent space seed. The number of timesteps required for the seed depends on the latent_forecaster model selected for SHRED: SINDy_forecaster requires only a single latent space timestep, while LSTM_forecaster requires a seed of length lags (lags is set on the LSTM_forecaster). The seed is passed into SHRED’s latent_forecaster, which then forecasts the latent space h timesteps into the future.#
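To make the seed requirement from the caption concrete, the snippet below picks the seed explicitly for each forecaster type. The variable lags is a placeholder for whatever value was set when the LSTM_forecaster was created; in practice you can simply read seed_length from the forecaster, as in the example above.

# For a SINDy_forecaster, the seed is a single latent timestep:
init_latents = current_latents[-1:]

# For an LSTM_forecaster, the seed must span the forecaster's lags
# (`lags` is a placeholder for the value chosen when the forecaster was built):
init_latents = current_latents[-lags:]

# The seed is passed the same way in either case:
forecasted_latents = engine.forecast_latent(h=forecast_horizon, init_latents=init_latents)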
Decode latent space back to full-state space#
The SHREDEngine provides a .decode() method for converting latent representations back into the full-state space.
# Convert latent representations back to full state
full_state_reconstruction = engine.decode(latent_representation)
# You can also decode forecasted latent states
forecasted_full_state = engine.decode(forecasted_latents)

Figure: The decode method takes in a latent space, passes it through SHRED’s decoder_model, and returns the full-state reconstruction of the latent space.#
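A common follow-up is to decode both the reconstructed and the forecasted latents and stitch them into one continuous full-state trajectory. The sketch below assumes the arrays from the previous examples and a single dataset, so that decode returns one array per call; with multiple datasets the return structure may differ.

import numpy as np

# Decode the latents for the observed window and for the forecast horizon
reconstruction = engine.decode(current_latents)
forecast = engine.decode(forecasted_latents)

# Concatenate along the time axis to get one continuous full-state trajectory
full_trajectory = np.concatenate([reconstruction, forecast], axis=0)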
Evaluate reconstructions from sensor measurements against ground truth#
The SHREDEngine provides an .evaluate() method that compares reconstructed full-state outputs (from sensor measurements) against the unprocessed ground truth. This method performs end-to-end evaluation in the physical space, automatically handling all necessary post-processing steps.
What the evaluation includes:#
- Unscaling: Converts normalized predictions back to original scales using fitted scalers
- Decompression: Projects back to full-state space if SVD compression was used during preprocessing
- Dataset unstacking: Separates multiple datasets if more than one was added to the DataManager
- Metric computation: Calculates comprehensive error metrics (MSE, RMSE, MAE, R²); see the sketch after this list
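The metrics in the last item follow their standard definitions. Below is a hand-rolled sketch of how they could be computed for a single dataset; it is for reference only and is not the library’s internal code.

import numpy as np

def error_metrics(y_true, y_pred):
    """Illustrative MSE / RMSE / MAE / R² over all timesteps and spatial points."""
    err = y_pred - y_true
    mse = np.mean(err ** 2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(err))
    ss_res = np.sum(err ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    return {"MSE": mse, "RMSE": rmse, "MAE": mae, "R2": r2}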
Usage example:#
# Prepare ground truth data as a dictionary mapping dataset IDs to arrays
# Each array should have shape (time_steps, *spatial_dimensions)
ground_truth = {
    "dataset_1": full_state_data_1,  # shape: (T, height, width) for 2D spatial data
    "dataset_2": full_state_data_2,  # shape: (T, nx, ny, nz) for 3D spatial data
}

# Evaluate model performance
evaluation_results = engine.evaluate(
    sensor_measurements=test_sensor_data,  # shape: (T, n_sensors)
    Y=ground_truth,
)

# Display results
print(evaluation_results)
#                 MSE   RMSE    MAE     R2
# dataset
# dataset_1     0.123  0.351  0.245  0.892
# dataset_2     0.098  0.313  0.198  0.913
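The printed table suggests evaluate() returns its results as a per-dataset table (e.g. a pandas DataFrame indexed by dataset name). Assuming that is the case, individual metrics can be pulled out directly; this access pattern is an assumption rather than a documented guarantee.

# Assumes evaluation_results behaves like a pandas DataFrame indexed by dataset
rmse_dataset_1 = evaluation_results.loc["dataset_1", "RMSE"]
print(f"dataset_1 RMSE: {rmse_dataset_1:.3f}")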