ADM reconstruction
!pip install git+https://github.com/ferqui/DynapSEtorch.git
import torch
import numpy as np
from dynapsetorch.model import ADM as ADMtorch
from dynapsetorch.datasets.EMG import RoshamboDataset
from torch.optim import Adamax, SGD
from tqdm import trange
import matplotlib.pyplot as plt
We initialize the dataset that we are going to use to optimize the ADM encoder.
# Fix the NumPy seed so the random recording selection below is reproducible.
np.random.seed(3)

# Simulation time step.
dt = 1e-4
# Number of simulation steps per original sample, for a sample period of
# 1/200 (presumably the recordings' 200 Hz sampling rate, with dt in
# seconds -- confirm).
# round() is used instead of int(): int() truncates, so a ratio that lands a
# hair below the intended integer because of floating-point error would come
# out one step short.
upsample = round((1 / 200) / dt)  # 50 for dt = 1e-4 (5 would match dt = 1e-3)
dataset = RoshamboDataset("~/Work/datasets", upsample=upsample)

# Random ordering of the dataset indices; the first n_batch entries are used.
idx = np.argsort(np.random.rand(len(dataset)))
n_batch = 1

# Each dataset item unpacks into a (signal, <ignored>) pair; keep the signal
# only and stack the selected recordings along a new leading batch dimension.
emg = torch.stack([dataset[idx[i]][0] for i in range(n_batch)], dim=0)
We create the ADM encoder and the optimizer for “learning” the ADM UP and DOWN thresholds.
# ADM encoder whose parameters are fitted below with plain SGD.
# NOTE(review): the positional arguments are presumably (number of channels,
# UP threshold, DOWN threshold, refractory period) -- confirm against
# dynapsetorch.model.ADM.
encoder = ADMtorch(8, 1.0, 1.0, 1)
optimizer = SGD(encoder.parameters(), lr=1e-4)

# The input signal is fixed data, so it is detached from any autograd graph
# once, up front, rather than on every epoch and at every use.
emg = emg.detach()

pbar = trange(100)
for _ in pbar:
    # Clear the encoder's stored state so every epoch processes the sequence
    # from scratch (presumably the encoder re-initialises these attributes on
    # its next call -- confirm).
    encoder.refrac = None
    encoder.DC_Voltage = None

    # Encode the signal one time step at a time. The encoder returns three
    # values per step; only the first one is used here.
    spikes = torch.stack(
        [encoder(emg[:, t])[0] for t in range(emg.shape[1])], dim=1
    )

    # Rebuild the signal from the spikes, starting from the first sample, and
    # score the reconstruction with the mean squared error.
    rec = encoder.reconstruct(spikes, initial_value=emg[:, 0, :])
    loss = torch.mean((rec - emg) ** 2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # NOTE(review): both entries display the same `encoder.threshold` value,
    # exactly as before. If the encoder exposes separate UP and DOWN
    # thresholds, each entry should read its own parameter.
    threshold = encoder.threshold.item()
    pbar.set_postfix(
        loss=loss.item(),
        threshold_up=threshold,
        threshold_down=threshold,
    )
100%|██████████| 100/100 [02:18<00:00, 1.38s/it, loss=29.8, threshold_down=1.28, threshold_up=1.28]
# Plot the original signal, its reconstruction, and the spike times for the
# first recording in the batch.
batch = 0

# Columns 0 and 8 of the encoder output (presumably the UP and DOWN spike
# trains of input channel 0 -- confirm against the ADM output layout).
selected_spikes = spikes[batch, :, [0, 8]].detach().numpy()
# Coordinates of every non-zero entry: first axis -> t, second axis -> idx.
t, idx = np.nonzero(selected_spikes)

_, ax = plt.subplots(figsize=(12, 6))
original_trace = emg[batch, :, 0].detach().numpy()
reconstructed_trace = rec[batch, :, 0].detach().numpy()
ax.plot(original_trace)
ax.plot(reconstructed_trace)
# One black dot per spike, at its time step and selected-column index.
ax.scatter(t, idx, marker=".", color="k")
plt.show()