SINDy-SHRED Tutorial on Sea Surface Temperature#
Import Libraries#
# PYSHRED
from pyshred import DataManager, SHRED, SHREDEngine, SINDy_Forecaster
# Other helper libraries
import matplotlib.pyplot as plt
from scipy.io import loadmat
import torch
import numpy as np
Load in SST Data#
sst_data = np.load("sst_data.npy")
# Plotting a single frame
plt.figure()
plt.imshow(sst_data[0])
plt.colorbar()
plt.show()
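Before handing the data to a manager, it is worth confirming the array layout: axis 0 is time and axes 1 and 2 are the spatial grid. A minimal check (the exact shape depends on your copy of the data; here it should be roughly 1400 snapshots on a grid of about 180 × 360):
# Confirm the array layout: (time, rows, cols)
print("shape:", sst_data.shape)
print("dtype:", sst_data.dtype)
print("value range:", sst_data.min(), "to", sst_data.max())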
Initialize Data Manager#
manager = DataManager(
    lags=52,          # each input sequence uses the previous 52 time steps of sensor data
    train_size=0.8,
    val_size=0.1,
    test_size=0.1,
)
Add datasets and sensors#
manager.add_data(
    data=sst_data,
    id="SST",
    random=50,        # place 50 stationary sensors at random grid locations
    # mobile=,
    # stationary=,
    # measurements=,
    compress=False,
)
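The commented-out arguments above hint at other sensor configurations (mobile trajectories, user-chosen stationary locations, or pre-recorded measurements). As a hedged sketch only, since this tutorial does not show the expected value format, fixed locations might look like the following, with the coordinate format assumed rather than confirmed:
# Hypothetical alternative (value format assumed): stationary sensors at chosen (row, col) grid points
# manager.add_data(
#     data=sst_data,
#     id="SST",
#     stationary=[(40, 343), (136, 255), (66, 219)],
#     compress=False,
# )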
Analyze sensor summary#
manager.sensor_summary_df
|   | data id | sensor_number | type | loc/traj |
|---|---|---|---|---|
0 | SST | 0 | stationary (random) | (40, 343) |
1 | SST | 1 | stationary (random) | (136, 255) |
2 | SST | 2 | stationary (random) | (66, 219) |
3 | SST | 3 | stationary (random) | (75, 5) |
4 | SST | 4 | stationary (random) | (2, 206) |
5 | SST | 5 | stationary (random) | (43, 229) |
6 | SST | 6 | stationary (random) | (121, 31) |
7 | SST | 7 | stationary (random) | (46, 136) |
8 | SST | 8 | stationary (random) | (43, 88) |
9 | SST | 9 | stationary (random) | (147, 153) |
10 | SST | 10 | stationary (random) | (105, 14) |
11 | SST | 11 | stationary (random) | (154, 16) |
12 | SST | 12 | stationary (random) | (142, 151) |
13 | SST | 13 | stationary (random) | (144, 320) |
14 | SST | 14 | stationary (random) | (171, 86) |
15 | SST | 15 | stationary (random) | (55, 318) |
16 | SST | 16 | stationary (random) | (154, 84) |
17 | SST | 17 | stationary (random) | (45, 16) |
18 | SST | 18 | stationary (random) | (142, 319) |
19 | SST | 19 | stationary (random) | (145, 241) |
20 | SST | 20 | stationary (random) | (82, 164) |
21 | SST | 21 | stationary (random) | (54, 195) |
22 | SST | 22 | stationary (random) | (105, 310) |
23 | SST | 23 | stationary (random) | (178, 123) |
24 | SST | 24 | stationary (random) | (112, 242) |
25 | SST | 25 | stationary (random) | (139, 358) |
26 | SST | 26 | stationary (random) | (149, 227) |
27 | SST | 27 | stationary (random) | (10, 98) |
28 | SST | 28 | stationary (random) | (109, 97) |
29 | SST | 29 | stationary (random) | (38, 245) |
30 | SST | 30 | stationary (random) | (97, 217) |
31 | SST | 31 | stationary (random) | (104, 23) |
32 | SST | 32 | stationary (random) | (155, 71) |
33 | SST | 33 | stationary (random) | (134, 41) |
34 | SST | 34 | stationary (random) | (66, 215) |
35 | SST | 35 | stationary (random) | (173, 194) |
36 | SST | 36 | stationary (random) | (172, 139) |
37 | SST | 37 | stationary (random) | (57, 337) |
38 | SST | 38 | stationary (random) | (56, 298) |
39 | SST | 39 | stationary (random) | (125, 245) |
40 | SST | 40 | stationary (random) | (160, 347) |
41 | SST | 41 | stationary (random) | (166, 283) |
42 | SST | 42 | stationary (random) | (85, 26) |
43 | SST | 43 | stationary (random) | (62, 346) |
44 | SST | 44 | stationary (random) | (96, 65) |
45 | SST | 45 | stationary (random) | (121, 335) |
46 | SST | 46 | stationary (random) | (128, 251) |
47 | SST | 47 | stationary (random) | (122, 62) |
48 | SST | 48 | stationary (random) | (96, 69) |
49 | SST | 49 | stationary (random) | (73, 237) |
manager.sensor_measurements_df
|   | SST-0 | SST-1 | SST-2 | SST-3 | SST-4 | SST-5 | SST-6 | SST-7 | SST-8 | SST-9 | ... | SST-40 | SST-41 | SST-42 | SST-43 | SST-44 | SST-45 | SST-46 | SST-47 | SST-48 | SST-49 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 11.74 | 10.08 | 22.190000 | 0.0 | -1.80 | 11.12 | 25.069999 | 2.78 | 0.0 | 3.28 | ... | 0.15 | -0.0 | 0.0 | 20.200000 | 29.159999 | 22.689999 | 15.22 | 21.740000 | 28.719999 | 24.539999 |
1 | 11.67 | 10.21 | 22.100000 | 0.0 | -1.80 | 10.70 | 24.049999 | 2.12 | 0.0 | 3.61 | ... | 0.06 | -0.0 | 0.0 | 19.910000 | 28.129999 | 23.609999 | 14.96 | 22.589999 | 28.189999 | 24.279999 |
2 | 11.73 | 10.61 | 21.890000 | 0.0 | -1.80 | 10.29 | 24.849999 | 1.53 | 0.0 | 3.63 | ... | 0.31 | -0.0 | 0.0 | 19.170000 | 28.789999 | 23.079999 | 15.90 | 21.410000 | 28.639999 | 24.899999 |
3 | 11.33 | 10.91 | 21.600000 | 0.0 | -1.80 | 9.87 | 24.879999 | 1.68 | 0.0 | 3.72 | ... | 0.53 | -0.0 | 0.0 | 19.180000 | 28.539999 | 23.349999 | 17.87 | 22.350000 | 28.559999 | 24.389999 |
4 | 11.17 | 11.19 | 21.590000 | 0.0 | -1.80 | 9.45 | 24.889999 | 1.35 | 0.0 | 3.76 | ... | 0.52 | -0.0 | 0.0 | 18.770000 | 28.059999 | 23.999999 | 17.69 | 22.100000 | 28.449999 | 24.219999 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1395 | 15.89 | 8.09 | 24.459999 | 0.0 | -1.70 | 17.31 | 21.770000 | 17.67 | 0.0 | 1.56 | ... | -1.75 | -0.0 | 0.0 | 21.540000 | 27.819999 | 18.070000 | 11.61 | 16.390000 | 27.649999 | 27.269999 |
1396 | 15.61 | 8.22 | 24.159999 | 0.0 | -1.78 | 16.82 | 20.870000 | 16.34 | 0.0 | 1.83 | ... | -1.70 | -0.0 | 0.0 | 21.690000 | 28.039999 | 17.930000 | 11.92 | 16.970000 | 27.649999 | 26.809999 |
1397 | 15.17 | 8.34 | 24.369999 | 0.0 | -1.80 | 16.12 | 20.980000 | 14.77 | 0.0 | 1.41 | ... | -1.65 | -0.0 | 0.0 | 21.740000 | 28.209999 | 17.940000 | 12.23 | 17.360000 | 27.799999 | 26.839999 |
1398 | 14.79 | 8.75 | 24.199999 | 0.0 | -1.80 | 15.60 | 21.330000 | 13.03 | 0.0 | 1.87 | ... | -1.47 | -0.0 | 0.0 | 21.840000 | 28.019999 | 18.140000 | 12.83 | 17.410000 | 27.619999 | 26.679999 |
1399 | 14.71 | 8.75 | 24.239999 | 0.0 | -1.80 | 14.83 | 21.560000 | 11.23 | 0.0 | 2.12 | ... | -1.32 | -0.0 | 0.0 | 22.449999 | 28.499999 | 18.500000 | 12.71 | 17.740000 | 28.489999 | 26.969999 |
1400 rows × 50 columns
Get train, validation, and test set#
train_dataset, val_dataset, test_dataset = manager.prepare()
Initialize a latent forecaster#
latent_forecaster = SINDy_Forecaster(poly_order=1, include_sine=True, dt=1/5)  # candidate library: polynomials up to degree 1 plus sine terms; dt is the latent time step
Initialize SHRED#
shred = SHRED(sequence_model="GRU", decoder_model="MLP", latent_forecaster=latent_forecaster)  # GRU encodes sensor sequences into the latent state; an MLP decodes the latent state back to the full field
Fit SHRED#
val_errors = shred.fit(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    num_epochs=10,
    sindy_thres_epoch=20,
    sindy_regularization=1,
)
print('val_errors:', val_errors)
Fitting SindySHRED...
Epoch 1: Average training loss = 0.091680
Validation MSE (epoch 1): 0.032684
Epoch 2: Average training loss = 0.032257
Validation MSE (epoch 2): 0.015280
Epoch 3: Average training loss = 0.020602
Validation MSE (epoch 3): 0.013595
Epoch 4: Average training loss = 0.018892
Validation MSE (epoch 4): 0.013358
Epoch 5: Average training loss = 0.017929
Validation MSE (epoch 5): 0.013441
Epoch 6: Average training loss = 0.017431
Validation MSE (epoch 6): 0.013291
Epoch 7: Average training loss = 0.017039
Validation MSE (epoch 7): 0.014045
Epoch 8: Average training loss = 0.016548
Validation MSE (epoch 8): 0.013414
Epoch 9: Average training loss = 0.015898
Validation MSE (epoch 9): 0.013693
Epoch 10: Average training loss = 0.015332
Validation MSE (epoch 10): 0.013768
val_errors: [0.03268439 0.01527982 0.01359512 0.01335787 0.01344112 0.01329135
0.01404459 0.01341446 0.0136935 0.01376815]
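Since fit returns the per-epoch validation errors, a quick plot makes it easy to check convergence:
# Plot validation MSE per epoch
plt.figure()
plt.plot(np.arange(1, len(val_errors) + 1), val_errors, marker="o")
plt.xlabel("epoch")
plt.ylabel("validation MSE")
plt.show()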
Evaluate SHRED#
train_mse = shred.evaluate(dataset=train_dataset)
val_mse = shred.evaluate(dataset=val_dataset)
test_mse = shred.evaluate(dataset=test_dataset)
print(f"Train MSE: {train_mse:.3f}")
print(f"Val MSE: {val_mse:.3f}")
print(f"Test MSE: {test_mse:.3f}")
Train MSE: 0.010
Val MSE: 0.014
Test MSE: 0.018
SINDy Discovered Latent Dynamics#
print(shred.latent_forecaster)
(x0)' = 0.269 1 + 0.326 x0 + -0.079 x1 + 0.486 x2
(x1)' = 0.449 1 + 0.333 x0 + -0.308 x1 + 0.189 x2
(x2)' = -0.492 1 + -0.529 x0 + 0.191 x1 + -0.292 x2
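The printed model is affine linear in the three latent coordinates. As a quick sanity check, you can integrate it directly and inspect the latent trajectories it generates. This is a minimal sketch that hard-codes the coefficients printed above (your run will produce different numbers):
from scipy.integrate import solve_ivp

# Coefficients copied from the printed model: x' = b + A @ x
b = np.array([0.269, 0.449, -0.492])
A = np.array([[ 0.326, -0.079,  0.486],
              [ 0.333, -0.308,  0.189],
              [-0.529,  0.191, -0.292]])

def latent_rhs(t, x):
    return b + A @ x

# Integrate 200 steps of the forecaster's time step (dt = 1/5) from an arbitrary initial state
dt = 1 / 5
t_eval = np.arange(200) * dt
sol = solve_ivp(latent_rhs, (t_eval[0], t_eval[-1]), np.zeros(3), t_eval=t_eval)

plt.figure()
for i in range(3):
    plt.plot(sol.t, sol.y[i], label=f"x{i}")
plt.xlabel("time")
plt.legend()
plt.show()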
Initialize SHRED Engine for Downstream Tasks#
engine = SHREDEngine(manager, shred)
Sensor Measurements to Latent Space#
test_latent_from_sensors = engine.sensor_to_latent(manager.test_sensor_measurements)
Forecast Latent Space (No Sensor Measurements)#
val_latents = engine.sensor_to_latent(manager.val_sensor_measurements)
init_latents = val_latents[-1] # seed forecaster with final latent space from val
h = len(manager.test_sensor_measurements)
test_latent_from_forecaster = engine.forecast_latent(h=h, init_latents=init_latents)
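To see how closely the forecaster tracks the encoder, overlay the two latent trajectories. This sketch assumes both are NumPy arrays of shape (time, latent_dim):
# Compare sensor-derived latents with forecasted latents, one panel per latent coordinate
n_latent = test_latent_from_sensors.shape[1]
fig, axes = plt.subplots(n_latent, 1, figsize=(8, 2 * n_latent), sharex=True)
for i, ax in enumerate(axes):
    ax.plot(test_latent_from_sensors[:, i], label="from sensors")
    ax.plot(test_latent_from_forecaster[:, i], "--", label="forecast")
    ax.set_ylabel(f"x{i}")
axes[0].legend()
axes[-1].set_xlabel("test time step")
plt.show()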
Decode Latent Space to Full-State Space#
test_prediction = engine.decode(test_latent_from_sensors) # latent space generated from sensor data
test_forecast = engine.decode(test_latent_from_forecaster) # latent space generated by the latent forecaster (no sensor data)
Compare the final frame of the prediction and the forecast to the ground truth:
truth = sst_data[-1]
prediction = test_prediction['SST'][-1]
forecast = test_forecast['SST'][-1]
data = [truth, prediction, forecast]
titles = ["Test Truth Ground Truth", "Test Prediction", "Test Forecast"]
vmin, vmax = np.min([d.min() for d in data]), np.max([d.max() for d in data])
fig, axes = plt.subplots(1, 3, figsize=(20, 4), constrained_layout=True)
for ax, d, title in zip(axes, data, titles):
im = ax.imshow(d, vmin=vmin, vmax=vmax)
ax.set(title=title)
ax.axis("off")
fig.colorbar(im, ax=axes, label="Value", shrink=0.8)
Evaluate MSE on Ground Truth Data#
# Train
t_train = len(manager.train_sensor_measurements)
train_Y = {'SST': sst_data[0:t_train]}
train_error = engine.evaluate(manager.train_sensor_measurements, train_Y)
# Val
t_val = len(manager.val_sensor_measurements)
val_Y = {'SST': sst_data[t_train:t_train+t_val]}
val_error = engine.evaluate(manager.val_sensor_measurements, val_Y)
# Test
t_test = len(manager.test_sensor_measurements)
test_Y = {'SST': sst_data[-t_test:]}
test_error = engine.evaluate(manager.test_sensor_measurements, test_Y)
print('---------- TRAIN ----------')
print(train_error)
print('\n---------- VAL ----------')
print(val_error)
print('\n---------- TEST ----------')
print(test_error)
---------- TRAIN ----------
MSE RMSE MAE R2
dataset
SST 0.580466 0.761883 0.425608 0.407465
---------- VAL ----------
MSE RMSE MAE R2
dataset
SST 0.936849 0.96791 0.498381 -0.579918
---------- TEST ----------
MSE RMSE MAE R2
dataset
SST 1.248095 1.117182 0.587646 -0.453009
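engine.evaluate above scores reconstructions driven by sensor measurements. As a complementary check, the sensor-free forecast can be scored against the same held-out frames. This minimal sketch assumes test_forecast['SST'] is aligned with the last t_test frames of sst_data:
# MSE of the sensor-free forecast against the held-out ground-truth frames
forecast_frames = test_forecast['SST']
truth_frames = sst_data[-t_test:]
forecast_mse = np.mean((forecast_frames - truth_frames) ** 2)
print(f"Forecast MSE vs ground truth: {forecast_mse:.3f}")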