ADM reconstruction

!pip install git+https://github.com/ferqui/DynapSEtorch.git
import torch
import numpy as np
from dynapsetorch.model import ADM as ADMtorch
from dynapsetorch.datasets.EMG import RoshamboDataset
from torch.optim import Adamax, SGD

from tqdm import trange
import matplotlib.pyplot as plt

We load the Roshambo EMG dataset and draw a random batch of recordings, which we will use to optimize the ADM encoder.

# Fix the NumPy seed so the randomly selected recording(s) are reproducible.
np.random.seed(3)
dt = 1e-4  # encoder time step (s)
# Upsampling factor from a 1/200 s sample period to steps of `dt`
# (dt=1e-4 -> 50; dt=1e-3 would give 5).
# NOTE(review): assumes the raw EMG is sampled at 200 Hz - confirm.
# round() instead of int(): a float ratio that lands just below an integer
# would otherwise be truncated to the wrong factor.
upsample = round((1 / 200) / dt)
dataset = RoshamboDataset("~/Work/datasets", upsample=upsample)

# Random permutation of the dataset indices (argsort of uniform noise).
# Kept in this form so the seeded selection matches earlier runs.
idx = np.argsort(np.random.rand(len(dataset)))
n_batch = 1
# Stack the signals of the first `n_batch` permuted samples into one batch
# tensor (batch is dim 0). The second item of each sample (label) is unused.
emg = torch.stack([dataset[j][0] for j in idx[:n_batch]], dim=0)

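Next, we create the ADM encoder and an SGD optimizer that “learns” the ADM UP and DOWN thresholds by minimizing the mean-squared error between the original EMG signal and its reconstruction from the encoder's spikes.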

encoder = ADMtorch(8, 1.0, 1.0, 1)
optimizer = SGD(encoder.parameters(), lr=1e-4)

# Number of optimization epochs; each epoch encodes the whole batch once.
n_epochs = 100

# The target signal is fixed data: detach it once, not on every epoch.
emg = emg.detach()

pbar = trange(n_epochs)
for epoch in pbar:
    # Clear the encoder's internal state before re-encoding the batch.
    # NOTE(review): assumes ADMtorch re-initialises these attributes when
    # they are None - confirm in dynapsetorch.
    encoder.refrac = None
    encoder.DC_Voltage = None

    # Feed the signal one time step (dim 1) at a time; keep the first output.
    spikes = []
    for t in range(emg.shape[1]):
        o, o_p, o_n = encoder(emg[:, t])
        spikes.append(o)
    spikes = torch.stack(spikes, dim=1)  # time re-assembled along dim 1

    # Reconstruct the signal from the spikes, seeded with the first sample.
    rec = encoder.reconstruct(spikes, initial_value=emg[:, 0, :])

    # Mean-squared reconstruction error drives the update of
    # encoder.parameters() via the SGD optimizer created above.
    loss = torch.mean((rec - emg) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Only the single `encoder.threshold` attribute is read, so report it
    # once. Previously it appeared under both "up" and "down" labels,
    # which always showed the same value.
    pbar.set_postfix(
        loss=loss.item(),
        threshold=encoder.threshold.item(),
    )
100%|██████████| 100/100 [02:18<00:00,  1.38s/it, loss=29.8, threshold_down=1.28, threshold_up=1.28]
batch = 0

# Spikes of one batch element: (time, output channels). Select output
# columns 0 and 8 in a separate step. Mixing an integer, a slice and a list
# in a single index follows different axis-ordering rules in NumPy and
# PyTorch; this form is always (time, 2).
# NOTE(review): columns 0 and 8 are presumably the UP/DOWN spike trains of
# input channel 0 - confirm against ADMtorch's output layout.
selected = spikes[batch][:, [0, 8]].detach().numpy()
# Distinct names so the dataset permutation `idx` is not overwritten.
spike_t, spike_col = np.where(selected)

plt.figure(figsize=(12, 6))
plt.plot(emg[batch][:, 0].detach().numpy(), label="EMG (channel 0)")
plt.plot(rec[batch][:, 0].detach().numpy(), label="reconstruction")
# Each spike is drawn at its time step, at y=0 (column 0) or y=1 (column 8).
plt.scatter(spike_t, spike_col, marker=".", color="k", label="spikes")
plt.legend()
plt.show()
../_images/05b20a3143081f1cacfaf6c0df2e6e622cd82d2ea253624a83827619f07f38e7.png