Basic SHRED Tutorial on Sea Surface Temperature (SST)

Import Libraries

# PYSHRED
from pyshred import DataManager, SHRED, SHREDEngine, LSTM_Forecaster

# Other helper libraries
import matplotlib.pyplot as plt
from scipy.io import loadmat
import torch
import numpy as np

Load the SST Data

# Load the SST array from disk; sst_data[0] (a single 2-D frame, since it is
# passed to imshow) is visualised below.
sst_data = np.load("sst_data.npy")

# Quick sanity check: render the first frame with a colour scale.
plt.figure()
plt.imshow(sst_data[0])
plt.colorbar()
plt.show()

png

Initialize the Data Manager

# Configure windowing and the train/validation/test split.
#   lags       -- length of the input history per sample (presumably time
#                 steps; confirm in the pyshred DataManager docs)
#   *_size     -- split fractions: 80% train, 10% validation, 10% test
manager = DataManager(
    lags=52,
    train_size=0.8,
    val_size=0.1,
    test_size=0.1,
)

Add Datasets and Sensors

# Register the SST array (given by file path) under the id "SST" and place
# 3 randomly located sensors in it; compress=False is passed explicitly.
# add_data also accepts `mobile`, `stationary` and `measurements` keywords
# for other sensor specifications -- they are not used in this tutorial
# (see the pyshred DataManager.add_data docs for their formats).
manager.add_data(
    data="sst_data.npy",
    id="SST",
    random=3,
    compress=False,
)

Analyze the Sensor Summary

# Display one row per sensor: data id, sensor number, placement type, and location.
manager.sensor_summary_df
data id sensor_number type loc/traj
0 SST 0 stationary (random) (81, 131)
1 SST 1 stationary (random) (155, 263)
2 SST 2 stationary (random) (153, 169)
# Display the raw sensor time series: one row per time index, one column per
# sensor, named "<data id>-<sensor number>" (e.g. "SST-0").
manager.sensor_measurements_df
SST-0 SST-1 SST-2
0 28.449999 2.93 0.94
1 28.619999 3.35 1.09
2 28.279999 3.03 1.53
3 28.169999 2.95 1.60
4 28.179999 3.01 1.95
... ... ... ...
1395 29.999999 -0.57 -1.06
1396 29.769999 -0.32 -1.21
1397 29.809999 -0.14 -0.92
1398 29.889999 0.00 -0.66
1399 29.969999 0.00 -0.50

1400 rows × 3 columns

Get the Train, Validation, and Test Sets

# Build the three dataset splits configured on the DataManager above.
(train_dataset, val_dataset, test_dataset) = manager.prepare()

Initialize SHRED

# Assemble the SHRED model, choosing each component by name. The latent
# forecaster is reached later through `shred.latent_forecaster`.
shred = SHRED(
    sequence_model="LSTM",
    decoder_model="MLP",
    latent_forecaster="LSTM_Forecaster",
)

Fit SHRED

# Train for 10 epochs. sindy_regularization=0 presumably disables the SINDy
# penalty (confirm in the pyshred SHRED.fit docs). fit() returns the
# validation error recorded for each epoch.
val_errors = shred.fit(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    num_epochs=10,
    sindy_regularization=0,
)
print('val_errors:', val_errors)
Fitting SHRED...
Epoch 1: Average training loss = 0.079502
Validation MSE (epoch 1): 0.036644
Epoch 2: Average training loss = 0.036274
Validation MSE (epoch 2): 0.034130
Epoch 3: Average training loss = 0.033781
Validation MSE (epoch 3): 0.034199
Epoch 4: Average training loss = 0.033450
Validation MSE (epoch 4): 0.033875
Epoch 5: Average training loss = 0.033123
Validation MSE (epoch 5): 0.033502
Epoch 6: Average training loss = 0.032411
Validation MSE (epoch 6): 0.032883
Epoch 7: Average training loss = 0.028727
Validation MSE (epoch 7): 0.022372
Epoch 8: Average training loss = 0.018934
Validation MSE (epoch 8): 0.016510
Epoch 9: Average training loss = 0.016190
Validation MSE (epoch 9): 0.014858
Epoch 10: Average training loss = 0.015114
Validation MSE (epoch 10): 0.015208
val_errors: [0.0366443  0.03413025 0.0341986  0.03387529 0.03350158 0.03288335
 0.0223723  0.01651014 0.01485774 0.01520807]

Evaluate SHRED

# Score the model on every split, in train -> val -> test order, using the
# MSE returned by SHRED.evaluate.
train_mse, val_mse, test_mse = (
    shred.evaluate(dataset=split)
    for split in (train_dataset, val_dataset, test_dataset)
)

print(f"Train MSE: {train_mse:.3f}")
print(f"Val   MSE: {val_mse:.3f}")
print(f"Test  MSE: {test_mse:.3f}")
Train MSE: 0.012
Val   MSE: 0.015
Test  MSE: 0.017

Initialize the SHRED Engine for Downstream Tasks

# Bundle the data manager and the fitted model; the engine is used below for
# latent encoding, latent forecasting, decoding, and evaluation.
engine = SHREDEngine(manager, shred)

Map Sensor Measurements to the Latent Space

# Encode the test-split sensor measurements into SHRED's latent space.
test_latent_from_sensors = engine.sensor_to_latent(manager.test_sensor_measurements)

Forecast the Latent Space (No Sensor Measurements)

# Seed the forecaster with the tail of the validation latents -- exactly
# `seed_length` steps, as the forecaster requires -- so the forecast starts
# where the validation period ends (assuming a chronological split).
val_latents = engine.sensor_to_latent(manager.val_sensor_measurements)
init_latents = val_latents[-shred.latent_forecaster.seed_length:]

# Forecast as many latent steps as there are test measurements, using no
# test sensor data at all.
h = len(manager.test_sensor_measurements)
test_latent_from_forecaster = engine.forecast_latent(
    h=h,
    init_latents=init_latents,
)

Decode the Latent Space to the Full-State Space

# Map both latent trajectories back to the full state. The results are
# indexed by dataset id (e.g. ['SST']) further down.
# - test_prediction: decoded from latents computed from test sensor data
# - test_forecast:   decoded from forecasted latents (no sensor data)
test_prediction = engine.decode(test_latent_from_sensors)
test_forecast = engine.decode(test_latent_from_forecaster)

Compare the final frame of the prediction and of the forecast with the ground truth:

# Side-by-side comparison of the last frame of the test period: ground truth,
# the sensor-driven prediction, and the sensor-free forecast.
# NOTE: sst_data[-1] is the last test frame only if the test split is the
# tail of sst_data (the ground-truth evaluation further down assumes the same).
truth = sst_data[-1]
prediction = test_prediction['SST'][-1]
forecast = test_forecast['SST'][-1]

data = [truth, prediction, forecast]
titles = ["Test Ground Truth", "Test Prediction", "Test Forecast"]

# One shared colour range for all panels so the colours are directly comparable.
vmin = np.min([d.min() for d in data])
vmax = np.max([d.max() for d in data])

fig, axes = plt.subplots(1, 3, figsize=(20, 4), constrained_layout=True)

for ax, d, title in zip(axes, data, titles):
    im = ax.imshow(d, vmin=vmin, vmax=vmax)
    ax.set(title=title)
    ax.axis("off")

# A single colorbar serves all three panels because they share vmin/vmax.
fig.colorbar(im, ax=axes, label="Value", shrink=0.8)
<matplotlib.colorbar.Colorbar at 0x1c2521a0ee0>

png

Evaluate Error Metrics (MSE, RMSE, MAE, R²) on Ground-Truth Data

# Evaluate the full-state reconstructions against the raw SST frames.
# The slices assume manager.prepare() split the time axis chronologically
# (train, then val, then test) -- TODO confirm against the pyshred docs.

# Train: the first t_train frames.
t_train = len(manager.train_sensor_measurements)
train_Y = {'SST': sst_data[0:t_train]}
train_error = engine.evaluate(manager.train_sensor_measurements, train_Y)

# Val: the t_val frames immediately after the training period. The length
# must come from the *validation* measurements; using the test measurements
# only works when val_size happens to equal test_size.
t_val = len(manager.val_sensor_measurements)
val_Y = {'SST': sst_data[t_train:t_train + t_val]}
val_error = engine.evaluate(manager.val_sensor_measurements, val_Y)

# Test: the last t_test frames.
t_test = len(manager.test_sensor_measurements)
test_Y = {'SST': sst_data[-t_test:]}
test_error = engine.evaluate(manager.test_sensor_measurements, test_Y)

print('---------- TRAIN ----------')
print(train_error)
print('\n---------- VAL   ----------')
print(val_error)
print('\n---------- TEST  ----------')
print(test_error)
---------- TRAIN ----------
              MSE      RMSE       MAE        R2
dataset                                        
SST      0.806486  0.898045  0.490421  0.365509

---------- VAL   ----------
              MSE      RMSE       MAE        R2
dataset                                        
SST      1.038167  1.018905  0.550346 -0.301883

---------- TEST  ----------
              MSE     RMSE       MAE        R2
dataset                                       
SST      1.196159  1.09369  0.597865 -0.431854