from astroML.datasets import sdss_corrected_spectra
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
# This function adjusts matplotlib settings for a uniform feel in the textbook.
# Note that with usetex=True, fonts are rendered with LaTeX. This may
# result in an error if LaTeX is not installed on your system. In that case,
# you can set usetex to False.
from astroML.plotting import setup_text_plots
setup_text_plots(fontsize=8, usetex=True)
# Fetch and prepare the data
# (downloads and caches the SDSS corrected-spectra sample on first use)
data = sdss_corrected_spectra.fetch_sdss_corrected_spectra()
spectra = sdss_corrected_spectra.reconstruct_spectra(data)
wavelengths = sdss_corrected_spectra.compute_wavelengths(data)
# normalize spectra by integrated flux and subtract out mean, for easier training
# NOTE(review): this requires `import numpy as np` at the top of the file.
spectranorms = np.mean(spectra, axis=1)  # per-spectrum mean flux
normedspectra = spectra / spectranorms[:, None]
meanspectrum = np.mean(normedspectra, axis=0)  # sample-mean normalized spectrum
normedspectra -= meanspectrum[None, :]  # network learns deviations from the mean
# split data into 3:1 train:test
torch.manual_seed(802) # seed used for book figure
dataset = torchdata.TensorDataset(torch.tensor(normedspectra))
trainnum = normedspectra.shape[0] // 4 * 3  # ~3/4 of the sample for training
traindata, testdata = torchdata.random_split(dataset, [trainnum, normedspectra.shape[0] - trainnum])
traindataloader = torchdata.DataLoader(traindata, batch_size=128, shuffle=True)
# define structure of variation autoencoder
class VAE(nn.Module):
    """Variational autoencoder for 1000-pixel spectra.

    One fully-connected hidden layer on each side of a 2-dimensional
    Gaussian latent space; the encoder emits the mean and log-variance
    of the latent distribution.
    """

    def __init__(self, nhidden=250):
        super(VAE, self).__init__()
        # Encoder: spectrum -> hidden -> (mu, logvar) of the latent Gaussian.
        self.encode_fc = nn.Linear(1000, nhidden)
        self.mu = nn.Linear(nhidden, 2)
        self.logvar = nn.Linear(nhidden, 2)
        # Decoder: latent sample -> hidden -> reconstructed spectrum.
        self.decode_fc = nn.Linear(2, nhidden)
        self.output = nn.Linear(nhidden, 1000)

    def encode(self, x):
        """Return (mu, logvar) of the latent Gaussian for input batch x."""
        hidden = F.relu(self.encode_fc(x))
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw a differentiable sample z ~ N(mu, exp(logvar)) via the
        reparameterization trick."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent sample z back to spectrum space."""
        hidden = F.relu(self.decode_fc(z))
        return self.output(hidden)

    def forward(self, x):
        """Encode, sample, decode; also return (mu, logvar) for the KL term."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
# add KL divergence to loss function
def VAEloss(criterion, recon_x, x, mu, logvar):
    """Total VAE loss: reconstruction error plus KL divergence.

    `criterion(recon_x, x)` gives the (sum-reduced) reconstruction term;
    the KL term is the closed-form divergence of N(mu, exp(logvar))
    from the standard normal prior, summed over batch and latent dims.
    """
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return criterion(recon_x, x) + kl_divergence
def train_model():
    """Train a VAE on the module-level train/test split and return it.

    Reads the globals `traindataloader`, `traindata`, and `testdata`.
    Runs for at most 1000 epochs with Adam (lr=1e-4); the learning rate
    is reduced on validation-loss plateaus, and training stops early once
    the validation loss has not improved by a relative 1e-3 for 10
    consecutive epochs.
    """
    model = VAE()
    # Sum-reduced MSE keeps the reconstruction term on the same scale as
    # the summed KL divergence inside VAEloss.
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=5, threshold=1e-3)
    min_valid_loss = float('inf')
    badepochs = 0  # consecutive epochs without sufficient improvement
    for t in range(1000):
        train_loss = 0
        for i, databatch in enumerate(traindataloader, 0):
            specbatch = databatch[0]  # TensorDataset yields 1-tuples
            optimizer.zero_grad()
            recon, mu, logvar = model(specbatch)
            loss = VAEloss(criterion, recon, specbatch, mu, logvar)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Evaluate on the held-out split without building a graph.
        with torch.no_grad():
            testspec = testdata[:][0]
            recon, mu, logvar = model(testspec)
            valid_loss = VAEloss(criterion, recon, testspec, mu, logvar)
        if t % 10 == 0:
            # Per-object average losses, for readability across split sizes.
            print('Epoch %3i: train loss %6.1f validation loss %6.1f' % \
                  (t, train_loss / len(traindata), valid_loss / len(testdata)))
        # stop training if validation loss has not fallen in 10 epochs
        if valid_loss > min_valid_loss*(1-1e-3):
            badepochs += 1
        else:
            min_valid_loss = valid_loss
            badepochs = 0
        if badepochs == 10:
            print('Finished training')
            break
        scheduler.step(valid_loss)
    return model
model = train_model()
# plot results: a 3x3 grid of decoded spectra on a lattice of latent values
with torch.no_grad():
    # sort latent parameters from most constrained to least constrained
    # (smaller mean log-variance => that latent dimension is better determined)
    testspec = dataset[:][0]
    recon, mu, logvar = model(testspec)
    zorder = np.argsort(np.mean(logvar.numpy(), axis=0))
    fig = plt.figure()
    fig.subplots_adjust(hspace=0, wspace=0)
    parvalues = [-2.,0.,2.]
    for i, z1 in enumerate(parvalues):
        for j, z2 in enumerate(parvalues):
            # get z1 to vary left to right, z2 bottom to top
            ax = fig.add_subplot(3, 3, (2-j)*len(parvalues)+i+1)
            z = np.zeros((1,2), dtype=np.float32)
            z[0, zorder] = z1, z2 # set z1 is more constrained of the two latent parameters
            spectrum = model.decode(torch.tensor(z))
            # add the mean spectrum back in, since training removed it
            ax.plot(wavelengths, meanspectrum+spectrum.numpy()[0,:])
            ax.text(6750, 3, '(%i, %i)' % (z1,z2))
            ax.set_xlim(3000, 8000)
            ax.set_ylim(-1, 4)
            # label only the middle-left panel's y axis
            if i == 0 and j == 1:
                ax.set_ylabel('flux')
            else:
                ax.yaxis.set_major_formatter(plt.NullFormatter())
            # label only the bottom-middle panel's x axis
            if j == 0 and i == 1:
                ax.set_xlabel(r'${\rm wavelength\ (\AA)}$')
            else:
                ax.xaxis.set_major_formatter(plt.NullFormatter())
plt.show()