Learning to detect the onset of slow activity after a generalized tonic–clonic seizure

Background Sudden unexpected death in epilepsy (SUDEP) is rare in the US; however, it accounts for 8–17% of deaths in people with epilepsy. It involves complicated physiological patterns, and it is still not clear which physiological or biological markers can serve as indicators to predict SUDEP so that care providers can intervene and treat patients in a timely manner. To this end, the UTHealth School of Biomedical Informatics (SBMI) organized a machine learning hackathon to call for advanced solutions https://sbmi.uth.edu/hackathon/archive/sept19.htm. Methods In recent years, deep learning has become the state of the art for many domains with large amounts of data. Although healthcare has accumulated a lot of data, the data are often not abundant enough for subpopulation studies where deep learning could be beneficial. Taking these limitations into account, we present a framework for applying deep learning to the detection of the onset of slow activity after a generalized tonic–clonic seizure, as well as to other EEG signal detection problems exhibiting data paucity. Results We conducted ten training runs for our full method and seven model variants, statistically demonstrating the impact of each technique used in our framework with a high degree of confidence. Conclusions Our findings point toward deep learning being a viable method for detecting the onset of slow activity, provided appropriate regularization is performed.


Background
Recent advancements in deep learning have significantly improved performance for classification and detection tasks [1,2]. However, generalization ability is still limited by the lack of sufficient high-quality training data in many domains. This holds true for many problems in the biomedical domain, where data is often limited (especially for sub-population studies), which constrains the capacity of highly powerful supervised deep learning frameworks [3]. Since deep learning is known for requiring a considerable amount of data [4], applying it to a problem such as the detection of markers (onset of slow activity) to predict critical patterns in a rare condition like SUDEP is not straightforward.

Method
Our method attempts to build a framework for applying recent advancements in deep learning [2,5,6,7] to detection problems such as detection of the onset of slow activity after a generalized tonic–clonic seizure, where the availability of training data is limited. We combine a variety of preprocessing (resampling), regularization (anti-aliased temporal downsampling [6], global temporal downsampling [8], global batch-wise z-scoring, kernel regularization [9]), and optimization (batch size [10], loss discount factor) techniques to work around the data paucity issue. We also develop a system for real-time visualization of our model's predictions to emphasize which parts of the signal contributed most to the decision https://www.youtube.com/watch?v=cDuRsh2pSRM.
Correspondence: cs.vance@icloud.com. University of Houston, Houston, TX, USA.

Overview
From a high level, we feed an EEG sequence $x$ into our binary classification model $y = f(x)$, which estimates the probability $y \approx P(y|x)$ that the sequence contains the onset of slow activity (i.e., the label). The chosen model architecture is a residual neural network [11] utilizing stacked convolution layers [12], skip connections [11], batch normalization [7], downsampling [6], and nonlinear activation functions. We train our model using mini-batch stochastic gradient descent (SGD). We implemented our model using Python 3.7 and tensorflow.keras [13]. Our full source code is available on GitHub: https://github.com/csvance/deep-onset-detection.

Data
The original source of the training data $D$ contains variable-length sequences composed of recordings from ten pairwise offsets of two adjacent EEG electrodes [14]. The sequences were recorded from 134 different patients, one variable-length sequence per patient [14]. It follows that $|D| = 134$. The EEG sampling rate $F_s$ is 200 Hz, and each timestep $t_n$ is labeled $y \in \{0, 1\}$ for the presence of slow activity [14]. We create a training set $T$ derived from this set, as described in Sequence generation. The validation dataset $V$ contains $|V| = 12{,}345$ ten-second sequences sampled from 34 patients with the same EEG channels and sampling rate [14]. Each sequence is labeled $y \in \{0, 1\}$. The validation set $V$ has a class imbalance for label $y$, with $|V_{pos}| = 3{,}219$ and $|V_{neg}| = 9{,}126$.

Inputs / output format
Inputs Detection of the onset of slow activity must happen within a short time span in order to be clinically useful. A sequence length of 10 seconds was chosen based on this requirement. It follows that the input sequence to the model contains $len_{seq_{input}} = 10r = 2000$ timesteps, where $r = F_s = 200$ is the number of samples per second. Each training example contains ten sequences of pairwise offsets. Considering both the sequence length and the number of channels, the input to our model has the shape $(len_{seq_{input}}, |F|) = (2000, 10)$.
Outputs Our model estimates $P(y|x)$, which is a scalar value ranging between 0 and 1. Hence, the output of our model has the shape $(1,)$.

Preprocessing
Sequence generation To make maximum use of the original training data, we first create, for each individual patient, a set $S_{pos}$ of as many positive sequences of length $len_{seq_{input}} = 2000$ as possible, starting with the sequence whose final timestep satisfies $t_f = t_{onset}$ and stopping after the sequence whose initial timestep satisfies $t_i = t_{onset}$. For memory efficiency, a stride of 5 timesteps was used when generating the sequences in $S_{pos}$. We then create a disjoint set $S_{neg}$ by randomly sampling at most $|S_{pos}|$ negative sequences, with replacement, from a uniform distribution over every possible negative training example (sequences with $t_f < t_{onset}$) from the same patient. This process is repeated for each patient, and the final training set $T$ is the union of all $S_{pos}$ and $S_{neg}$ sets.
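The window enumeration described above can be sketched in a few lines. This is only an illustration under stated assumptions, not the authors' implementation: the function names `positive_window_starts` and `sample_negative_starts` are hypothetical, the onset index is assumed to be known per patient, and windows that would run past the recording boundaries are simply skipped.

```python
import random


def positive_window_starts(t_onset, n_timesteps, seq_len=2000, stride=5):
    """Start indices of windows that contain the onset timestep.

    The first window ends at the onset (t_f == t_onset); later windows
    slide forward by `stride` until the start would pass the onset.
    Windows falling outside the recording are skipped (an assumption).
    """
    first = max(0, t_onset - (seq_len - 1))
    last = min(t_onset, n_timesteps - seq_len)
    return list(range(first, last + 1, stride))


def sample_negative_starts(t_onset, n_samples, seq_len=2000, rng=random):
    """Sample window starts, with replacement, from all windows with t_f < t_onset."""
    n_candidates = t_onset - seq_len + 1  # starts 0 .. t_onset - seq_len
    if n_candidates <= 0:
        return []  # no complete window fits before the onset
    return rng.choices(range(n_candidates), k=n_samples)


pos_starts = positive_window_starts(t_onset=3000, n_timesteps=10000)
neg_starts = sample_negative_starts(3000, len(pos_starts), rng=random.Random(0))
```

With a stride of 5, a 2000-step window yields at most 400 positive windows per patient, which is why the stride directly trades memory for the number of distinct onset positions seen in training.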
Resampling Before training, 50% of the generated sequences were randomly cropped relative to the first timestep, resulting in a new sequence $seq'_{input}$ with length $len_{seq'_{input}} = u \cdot len_{seq_{input}}$, where $u \in [0.9, 1.1]$ is sampled from a uniform distribution. $seq'_{input}$ was then resampled to the original length $len_{seq_{input}} = 2000$. While this is a commonly used image augmentation technique for object detection [15,16], it should also be beneficial here, since we are interested in augmenting the temporal relationship between frequency and phase rather than the frequency and phase themselves.
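A minimal single-channel sketch of this augmentation follows. Several details are assumptions made for illustration: linear interpolation is used for resampling (the text does not name the interpolation method), the crop is taken from a longer recording starting at the window's first timestep so that $u > 1$ is possible, and the function names are hypothetical.

```python
import random


def resample_linear(x, new_len):
    """Linearly interpolate the 1-D sequence x (len >= 2) onto new_len points."""
    if new_len == 1:
        return [x[0]]
    scale = (len(x) - 1) / (new_len - 1)
    out = []
    for i in range(new_len):
        pos = i * scale
        lo = min(int(pos), len(x) - 2)  # left neighbour, clamped at the end
        frac = pos - lo
        out.append((1.0 - frac) * x[lo] + frac * x[lo + 1])
    return out


def random_time_stretch(recording, start, seq_len=2000, p=0.5, rng=random):
    """With probability p, crop u * seq_len samples at `start` and resample to seq_len."""
    if rng.random() >= p:
        return list(recording[start:start + seq_len])
    u = rng.uniform(0.9, 1.1)
    crop_len = min(int(round(u * seq_len)), len(recording) - start)
    return resample_linear(recording[start:start + crop_len], seq_len)


recording = [float(i) for i in range(5000)]  # toy ramp signal
stretched = random_time_stretch(recording, 100, p=1.0, rng=random.Random(1))
```

Because the output length is fixed, a crop with $u < 1$ slows the signal down and a crop with $u > 1$ speeds it up, which shifts every frequency by the same factor while preserving their relative timing.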

Network architecture
Other researchers have demonstrated success with residual neural network variants for detecting complicated patterns in signals [2]. Thus, we use a similar variation of ResNet with pre-activation style blocks [5] as a starting point, as shown in Fig. 4. Through trial and error, we settled on a $D = 32$ dimensional kernel for the first few convolution layers, increasing to $2D$ and ending with $4D$. Increasing $D = 32$ by factors of 2 resulted in overfitting. Likewise, reducing $D = 32$ by factors of 2 resulted in underfitting. With $D = 32$, our model has $p = 165{,}664$ trainable parameters.
Anti-aliased temporal downsampling We explored several different methods of temporal downsampling in our network architecture, and also investigated recent advancements in reducing aliasing [6]. After deciding on the other hyperparameters, we trained our network with an anti-aliased version of strided downsampling. We use a three-point Gaussian low-pass kernel with $\sigma \approx 0.79577$ during downsampling. We use the same $\sigma$ for each of the three downsampling operations to encourage the network to learn a feature representation increasingly focused on lower frequencies. Each downsampling operation divides the length of the temporal axis by two.
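The blur-then-stride operation can be illustrated on a single channel as follows. This is a sketch under assumptions rather than the network layer itself: the kernel is taken to be a Gaussian sampled at offsets -1, 0, 1 and normalized to sum to one, edges are padded by replication, and the function names are hypothetical.

```python
import math


def gaussian_kernel3(sigma=0.79577):
    """Three-tap Gaussian sampled at offsets -1, 0, 1, normalized to sum to 1."""
    w = [math.exp(-(k * k) / (2.0 * sigma * sigma)) for k in (-1, 0, 1)]
    total = sum(w)
    return [v / total for v in w]


def blur_downsample(x, sigma=0.79577):
    """Low-pass filter x with the three-tap kernel, then keep every second sample."""
    k = gaussian_kernel3(sigma)
    padded = [x[0]] + list(x) + [x[-1]]  # replicate the edges (assumption)
    blurred = [
        k[0] * padded[i] + k[1] * padded[i + 1] + k[2] * padded[i + 2]
        for i in range(len(x))
    ]
    return blurred[::2]


# Three successive operations reduce 2000 timesteps to 250.
signal = [0.0] * 2000
for _ in range(3):
    signal = blur_downsample(signal)

# A signal alternating at the Nyquist rate is strongly attenuated before striding.
nyquist = blur_downsample([1.0, -1.0] * 8)
```

Plain strided downsampling of the alternating signal would return a constant sequence of ones, which is exactly the aliasing artifact the low-pass step is meant to suppress.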
Global temporal downsampling Recent papers in deep learning have increasingly relied on global pooling layers to reduce the number of trainable parameters and improve generalization for a variety of problems [8,11,17]. We considered several different global downsampling strategies, including global max pooling (GMP), global average pooling (GAP) [8], and flattening. GAP was excluded because it may not be able to effectively handle sequences where only a small percentage of the sequence contains the onset. Flattening significantly increases the number of trainable parameters and may bias the model towards certain parts of the sequence in the training set. GMP provides the largest activation value from each channel regardless of where it occurred. With these considerations in mind, GMP was selected for global temporal downsampling at the top of the network.
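The argument against GAP can be made concrete with a toy feature map. The numbers below are invented purely for illustration (a single channel with one brief activation in 250 timesteps), and the helper names are hypothetical.

```python
def global_max_pool(features):
    """features: list of timesteps, each a list of per-channel activations."""
    return [max(channel) for channel in zip(*features)]


def global_avg_pool(features):
    """Average each channel over the temporal axis."""
    return [sum(channel) / len(channel) for channel in zip(*features)]


# One channel, 250 timesteps, with a single strong activation at t = 240.
features = [[0.0] for _ in range(250)]
features[240][0] = 5.0

gmp = global_max_pool(features)  # keeps the peak regardless of position
gap = global_avg_pool(features)  # dilutes the peak across all timesteps
```

In this toy case the pooled value under GMP stays at the peak activation, while under GAP it shrinks in proportion to how brief the event is, which is the situation described for sequences containing only a small part of the onset.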

Online augmentation
During training, online augmentations were employed to help the network learn how to handle differences in variance and bias from patient to patient. We employ global batch-wise z-scoring, which, when combined with a small stride during sequence generation, smaller batch sizes, and sample-wise shuffling, forces the network to generalize to a considerable number of different scales and biases.
Global batch-wise z-scoring Z-scoring was done along the batch, temporal, and channel axes, normalizing the entire batch using a single mean and standard deviation. Let $B_n$ be a mini-batch of shape $(|B|, len_{seq_{input}}, |F|) = (16, 2000, 10)$ for a batch size of 16. Each mini-batch $B_n$ is randomly sampled without replacement from a uniform distribution at the start of every training epoch. We calculate the mean $\mu_{batch}$ and standard deviation $\sigma_{batch}$ by reducing all three axes to a single scalar value. We then apply standard z-scoring as follows: $B'_{train} = \frac{B_{train} - \mu_{batch}}{\sigma_{batch}}$. $B'_{train}$ is then used to calculate the loss during training. When validating our model's performance, we instead z-score the validation set using the training set population mean and standard deviation.
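As a small illustration of the normalization above, the following sketch z-scores a nested list of shape (batch, time, channel) with one scalar mean and one scalar standard deviation. The function name and the small `eps` term guarding against division by zero are assumptions added for the example.

```python
import math


def batch_zscore(batch, eps=1e-8):
    """z-score a (batch, time, channel) nested list using a single mean and std."""
    flat = [v for seq in batch for step in seq for v in step]
    mu = sum(flat) / len(flat)
    sigma = math.sqrt(sum((v - mu) ** 2 for v in flat) / len(flat))
    return [[[(v - mu) / (sigma + eps) for v in step] for step in seq] for seq in batch]


# Toy batch: 2 sequences, 2 timesteps, 2 channels.
batch = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
normalized = batch_zscore(batch)

# Rescaling and shifting the whole batch leaves the normalized values unchanged.
shifted = [[[2.0 * v + 10.0 for v in step] for step in seq] for seq in batch]
normalized_shifted = batch_zscore(shifted)
```

Because the statistics come from whichever 16 sequences happen to share a batch, the same sequence is presented at a slightly different scale and offset each epoch, which is the augmentation effect described above.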

Loss
Since our neural network is a binary classifier, we used a binary cross-entropy based cost function to train the network.
Kernel regularization To discourage the model from overemphasizing a small subset of learned features which may be biased towards the training set, we used $L_2$ kernel regularization. A penalty coefficient of 0.01 was chosen for the $L_2$ penalty on all convolution kernels through trial and error [9].
Loss discount factor While GMP may help with cases where only a small part of the onset is present, some positive sequences generated using our methodology contain only a small number of positive timesteps, which may negatively impact convergence. If more data were available, we could simply omit ambiguous regions during training. Due to data paucity, however, another solution is needed. We define a cost discounting function $\alpha(p)$, where $p$ is defined as the number of positive timesteps in a sequence divided by the total length of the sequence: This effectively discounts the loss during the first second after the onset, starting from complete discount at $t_{onset} = t_{final}$ and ending with no discount at $t_{onset} = t_{final} - r$, with our discount linearly decreasing as $t_{onset} \to t_{final} - r$. Since our classes are balanced, we chose to discount a proportional amount from all negative examples in order to avoid bias. Finally, we define our cost function as:
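The expressions for $\alpha(p)$ and the cost function are not reproduced in the text above, so the following is only a hypothetical sketch consistent with the verbal description, not the authors' definition. It assumes that $\alpha(p)$ acts as a per-example weight that rises linearly from 0 (complete discount) when no timesteps are positive to 1 (no discount) once a full second, $r = 200$ of 2000 timesteps, is positive, and that the weight multiplies a standard binary cross-entropy term. The proportional discount applied to negative examples is not modeled here.

```python
import math


def discount_weight(p, seq_len=2000, r=200):
    """Hypothetical alpha(p): 0 at p = 0, rising linearly to 1 at p = r / seq_len."""
    return min(1.0, p * seq_len / r)


def weighted_bce(y_true, y_pred, weight, eps=1e-7):
    """Binary cross-entropy for one example, scaled by a discount weight."""
    y_pred = min(max(y_pred, eps), 1.0 - eps)  # avoid log(0)
    bce = -(y_true * math.log(y_pred) + (1 - y_true) * math.log(1.0 - y_pred))
    return weight * bce


# A positive window in which the onset occurred half a second before the end.
p = 100 / 2000
loss = weighted_bce(1, 0.3, discount_weight(p))
```

Under these assumptions, a window whose onset falls in its final half second contributes half of its usual loss, so ambiguous examples still provide a gradient without dominating training.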

Optimization
We optimized our network during training using minibatch stochastic gradient descent (SGD).
Batch size We used a mini-batch size of 16 during each training step. While a much higher batch size could easily fit into memory during training, smaller batch sizes result in a wider range of scale and bias when utilizing batch-wise z-scoring. Smaller batch sizes have also been observed to have a regularizing effect on the model when training with SGD [10].
Training parameters We selected an initial learning rate of $\eta_i = 0.0001$, decaying by a factor of 2 every 15 epochs for a total of 75 epochs. Momentum was set to $\beta = 0.9$.
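The decay schedule can be written as a one-line function. The sketch assumes a staircase schedule with zero-indexed epochs, which is one plausible reading of the description; the function name is hypothetical.

```python
def learning_rate(epoch, initial=1e-4, factor=2.0, step=15):
    """Staircase decay: divide the initial rate by `factor` every `step` epochs."""
    return initial / (factor ** (epoch // step))


# Learning rate for each of the 75 training epochs (epochs 0 to 74).
schedule = [learning_rate(epoch) for epoch in range(75)]
```

Under this reading the rate is halved four times over the run, so the final 15 epochs train at one sixteenth of the initial rate.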
Experimental setup While developing our method, we observed a high variability of outcome with different random seeds. In order to test the reliability of our methods, we conducted ten runs using different random seeds with our method during training.
Method variants In addition to our full method, we applied the same experimental setup to different variants omitting batch-wise z-scoring, $L_2$ kernel regularization, and anti-aliased downsampling. For the z-scoring variant, we normalize each sequence with its own mean and standard deviation during training and validation. The $L_2$ variant simply omits the $L_2$ penalty. The variant without anti-aliased downsampling performs a strided downsampling before the residual connection and adds a max pooling layer on the residual in order to match the sequence lengths. Two additional variants use batch sizes of 32 and 64. Finally, we created a baseline variant without batch z-scoring, $L_2$ regularization, anti-aliased downsampling, and the discount factor. For this variant we chose a batch size of 64. All variants share the same ten random seeds used in the full method for comparison.
Metrics Due to class imbalance in the validation set, we use the receiver operating characteristic area under the curve (ROC-AUC) to evaluate the accuracy of our model. Despite the imbalance, we are also interested in the trade-off between sensitivity and specificity for each of our variants. To compute sensitivity and specificity, values of $y_{pred} > 0.5$ are considered true, and values of $y_{pred} \le 0.5$ are considered false. The same threshold also applies for accuracy.
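For reference, these metrics can be computed as in the sketch below. The ROC-AUC is computed through its rank interpretation (the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counting half), which is a standard equivalent formulation rather than anything specific to this work. The function names and the toy labels and scores are invented for illustration.

```python
def sensitivity_specificity(y_true, y_pred, threshold=0.5):
    """Sensitivity and specificity, treating y_pred > threshold as a positive call."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p > threshold)
    fn = sum(1 for t, p in pairs if t == 1 and p <= threshold)
    tn = sum(1 for t, p in pairs if t == 0 and p <= threshold)
    fp = sum(1 for t, p in pairs if t == 0 and p > threshold)
    return tp / (tp + fn), tn / (tn + fp)


def roc_auc(y_true, y_pred):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [p for t, p in zip(y_true, y_pred) if t == 1]
    neg = [p for t, p in zip(y_true, y_pred) if t == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))


y_true = [1, 1, 1, 0, 0, 0, 0]
y_pred = [0.9, 0.6, 0.4, 0.5, 0.2, 0.7, 0.1]
sens, spec = sensitivity_specificity(y_true, y_pred)
auc = roc_auc(y_true, y_pred)
```

Unlike sensitivity and specificity, the ROC-AUC does not depend on the 0.5 threshold, which is why it is the more robust summary when the classes are imbalanced.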

Results
Accuracy over ten training runs is shown in Table 1. Table 2 shows the best single validation ROC-AUC of each variant. Finally, Table 3 shows the result of 20 additional training runs for our full method.

Average accuracy
Our full model had the highest average ROC-AUC and the highest and most consistent accuracy of all our variants. In the variant which omitted batch-wise z-scoring, we observe a significant increase in metric variance as well as the lowest average sensitivity and ROC-AUC. We hypothesize that there is not enough variance in scale and bias in the training set without this augmentation. The variant without $L_2$ regularization struggled with ROC-AUC and specificity, while having slightly higher average sensitivity than our full method. Even considering that our model only has ≈ 165K trainable parameters, without $L_2$ kernel regularization there is clear evidence that a small number of features are overemphasized. Our variant without anti-aliasing has a higher sensitivity than our full method. However, this comes at a significant cost in specificity. We hypothesize that this is due to the model associating aliasing with the presence of the onset, and that anti-aliasing and/or removal of high-frequency information is important for reducing the frequency of false positives. The variant without loss discounting was the closest to our best results, trading off more specificity than was gained in sensitivity. In both cases, increasing the batch size from 16 has a

Maximum accuracy
We observe that our full method has the highest single-epoch ROC-AUC of all variants. All of our variants appear to be heavily dependent on weight initialization and mini-batch selection during training, with many separate training runs needed to achieve the highest generalization. We hypothesize that this is due to both the paucity of the data set and unstable gradients caused by lower batch sizes.
In addition to the ten runs for our full method, we conducted approximately twenty additional runs for our full method with new random seeds. In Table 3, we show the best overall model in terms of ROC-AUC. The model has much higher sensitivity without sacrificing a significant amount of specificity. We use this model for all following discussion and visualization of model behavior.

Explaining model predictions
Salience To help explain our model's predictions, we computed the gradient of $y$ with respect to input sequences from the test set and summed the absolute value of the gradient over the feature channels: $salience(t) = \sum_{c=1}^{|F|} \left| \frac{\partial y}{\partial x_{t,c}} \right|$. For visualization purposes, we normalize salience by its value at the timestep containing the maximum: $salience_{vis}(t) = \frac{salience(t)}{salience(t_{max})}$. In each visualization we see only strong, sparse activation contributing to the model's decision due to the GMP layer at the top of the network.
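Given a gradient array of shape (timesteps, channels), the reduction and normalization described above amount to the following sketch. Obtaining the gradient itself requires the trained model and an automatic differentiation framework and is not shown; the function names and the toy gradient values are assumptions for illustration.

```python
def salience(grad):
    """Sum of absolute gradient values across channels, one value per timestep."""
    return [sum(abs(g) for g in step) for step in grad]


def normalize_for_display(values):
    """Scale so that the most salient timestep has value 1."""
    peak = max(values)
    if peak == 0.0:
        return list(values)  # nothing to scale in an all-zero map
    return [v / peak for v in values]


# Toy gradient: 3 timesteps, 2 channels.
grad = [[0.1, -0.2], [0.0, 0.0], [-0.5, 0.5]]
salience_vis = normalize_for_display(salience(grad))
```

Taking absolute values before summing ensures that channels with gradients of opposite sign reinforce rather than cancel each other in the visualization.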
Example: true positive Arguably the strongest activation overall appears to happen when almost every channel simultaneously increases, which can happen several times around the onset. We visualize this in Fig. 1, where we observe strong activation on the rising edge of a global increase.
Example: false negative Only some instances of the onset of slow activity exhibit strong cross-channel correlation, as demonstrated in Fig. 2. While most channels appear to move simultaneously, there is less positive correlation as well as some negative correlation between channels. In this particular example, there appears to be a wide spread of channel bias and low dynamic range. We hypothesize that z-scoring using the population mean and standard deviation may not be optimal for all examples, and that an adaptive strategy could improve validation performance.
Example: false positive Fig. 3 demonstrates that not all instances of cross-channel correlation are useful for predicting the onset by themselves. We hypothesize that a model may need to take into account the temporal nature of the problem in order to avoid these types of false positives.

Conclusions
While our naive baseline model had relatively poor accuracy, we demonstrated the impact of many different regularization techniques. It follows that deep learning can be an effective tool for signal detection problems with a small amount of available training data. By conducting our experiment over many different training runs, we show the statistical significance of our results. Finally, we demonstrated that while our model may be a black box, we can make the results easier to interpret with salience and effective visualization.

Future work
We recognize that the loss discount factor could be made into a continuous function across the entire sequence.
Currently, examples with a negative label could contain the start of the onset because the labeling task is particularly challenging, but they are weighted as heavily as unambiguous examples. In addition, we observed examples of false positives which would be relatively easy for a human to classify correctly due to drastic changes in overall behavior patterns. An improved model would be able to recognize these changes over time in addition to identifying cross-channel correlation.