An interpretable risk prediction model for healthcare with pattern attention

Background The availability of massive amounts of data enables the possibility of clinical predictive tasks, and deep learning methods have achieved promising performance on such tasks. However, most existing methods suffer from three limitations: (1) There are many missing values for real-value events; many methods impute the missing values and then train their models on the imputed values, which may introduce imputation bias and makes the models' performance highly dependent on the imputation accuracy. (2) Many existing studies take only Boolean-value medical events (e.g., diagnosis codes) as inputs and ignore real-value medical events (e.g., lab tests and vital signs), which are more important for acute disease (e.g., sepsis) and mortality prediction. (3) Existing interpretable models can illustrate which medical events contribute to the output results, but cannot give the contributions of patterns among medical events. Methods In this study, we propose a novel interpretable Pattern Attention model with Value Embedding (PAVE) to predict the risks of certain diseases. PAVE takes the embeddings of various medical events, their values and the corresponding occurring times as inputs, and leverages a self-attention mechanism to attend to meaningful patterns among medical events for risk prediction tasks. Because only the observed values are embedded into vectors, we do not need to impute the missing values, which avoids imputation bias. Moreover, the self-attention mechanism supports model interpretability: the proposed model can output which patterns cause high risks. Results We conduct sepsis onset prediction and mortality prediction experiments on the publicly available MIMIC-III dataset and our proprietary EHR dataset. The experimental results show that PAVE outperforms existing models. Moreover, by analyzing the self-attention weights, our model outputs meaningful medical event patterns related to mortality.
Conclusions PAVE learns effective medical event representations by incorporating the values and occurring times, which improves risk prediction performance. Moreover, the presented self-attention mechanism can not only capture patients' health state information, but also output the contributions of various medical event patterns, which paves the way for interpretable clinical risk predictions. Availability The code for this paper is available at: https://github.com/yinchangchang/PAVE.


Background
With the rapid growth of Electronic Health Records (EHRs) in both volume and diversity during the last decades, it has become possible to apply clinical predictive models to improve the quality of clinical care. EHRs are temporal sequence data and consist of diagnosis codes, medications, lab results, and vital signs. Patient health information contained in massive EHRs is extremely useful for different tasks within the medical domain, such as risk prediction [1,2], patient subtyping [3,4], treatment effect estimation [5,6], and patient similarity analysis [7]. In this paper, we focus on clinical risk prediction tasks. Most state-of-the-art clinical risk prediction models are based on deep learning and trained in an end-to-end way. The Recurrent Neural Network (RNN), a popular deep learning model for sequences, has recently achieved good performance in clinical risk prediction tasks [8-10]. However, there are still some challenges in the field. (1) Most existing methods [11,12] represent medical events as embedding vectors, which lose the real-value information of the medical events (e.g., lab tests and vital signs). (2) Lab tests are diagnosis-driven, so EHRs contain many missing lab test values. Many methods [13] impute the missing values and then train their models on the imputed values; the models' performance is highly dependent on the imputation accuracy. (3) Existing interpretable models can only provide instance-wise variable importance (i.e., compute each medical event's contribution to the disease risks) rather than pattern-wise importance. When some clinical events occur simultaneously, they may lead to a sharp increase in risk even though each event alone does not cause high risk.
In this study, we propose a new interpretable Pattern Attention model with Value Embedding (PAVE), which is entirely based on attention mechanisms. For each patient, medical events, values (e.g., lab test and vital sign values) and their corresponding occurring times are represented as embedding vectors and projected into a medical semantic space. A self-attention layer is then leveraged to capture meaningful patterns among medical events. A pattern attention module is proposed to attend to the event patterns and produce an attention vector for each patient. Finally, we use a fully connected layer to predict a patient's risk of future clinical outcomes. By analyzing the self-attention weights and pattern attention weights, our model is able to compute the contribution rates of various medical event patterns, thus paving the way for interpretable clinical risk predictions.
In order to demonstrate the effectiveness of the proposed PAVE, we compare our model against both traditional machine-learning methods (e.g., logistic regression, random forest) and recent deep-learning methods (e.g., RETAIN) on sepsis and mortality risk prediction tasks. We conducted experiments on both the publicly available MIMIC-III dataset [14] and our proprietary EHRs data. The experimental results show that PAVE outperforms all the baselines on both datasets and under various settings, which demonstrates the effectiveness of the proposed model. Moreover, once PAVE is well trained, it is also able to find the EHR event patterns with high contribution rates to high mortality risks. The highlights of the proposed framework are as follows:

• We propose a novel interpretable risk prediction model, PAVE, which is based on a self-attention mechanism and achieves better performance than the baselines.
• The presented self-attention mechanism can automatically capture meaningful patterns and helps find the patterns related to high risks. To the best of our knowledge, this work is the first attempt to identify the contributions of patterns.
• We propose a new value embedding that maps values into vectors, so we do not need to impute the missing values.
• Our medical event embedding module takes medical events' occurring time into account.

Related work
Due to their promising performance on clinical risk prediction tasks, deep learning methods have attracted significant interest from healthcare researchers. In this section, we review existing work related to deep learning models, including risk prediction, attention mechanisms, and clinical models' interpretability.
Risk prediction
Deep learning models learn effective representations of patients' EHRs data, which can be used for patient subtyping [3], phenotyping [19], similarity measurement [7], and missing value imputation [13,20], all of which are highly related to risk prediction tasks. Some RNN-based approaches do not consider the relationships between subsequent visits. To address this issue, Dipole [11] adopts attention mechanisms to capture the relations between visits and thereby significantly improves prediction accuracy. When preprocessing EHRs data, most existing models ignore the time intervals between neighboring medical events. However, such time intervals are common and important in many healthcare applications. Therefore, a time-aware patient subtyping model [3] was proposed to take into account the time intervals in patients' EHRs data; it demonstrated that considering time intervals can significantly improve a model's performance.

Attention mechanism
There are many kinds of medical events (e.g., diagnoses and medications) in EHRs data, which include redundant and useless information. Only the events related to specific diseases are crucial for risk prediction. Therefore, attention mechanisms have been introduced to automatically attend to the useful events [8,11,21].
The attention mechanism has been shown to be helpful in the natural language processing domain. Vaswani et al. proposed the Transformer [22] for the machine translation task. The Transformer uses self-attention to capture the relations between input words within a sentence. The self-attention mechanism is highly parallelizable and easy to train. This work adopts a self-attention mechanism for clinical risk prediction tasks and simultaneously aims to find clinically significant patterns related to sepsis and mortality risk.

Interpretability
In the clinical domain, models' interpretability can be even more important than their performance. Black-box approaches, especially deep learning methods, are not trusted by doctors and therefore are not applied in real clinical settings. This has motivated a lot of work on the interpretability of risk prediction models. RETAIN [8] is the first work that can interpret why the model makes particular predictions. It utilizes two attention modules (i.e., visit-level and code-level attention) that detect influential visits and significant medical codes. The attention weights of events indicate their importance for clinical outputs.
RETAIN then inputs the weighted average of each patient's event embeddings to a fully connected layer to predict the risk, which loses temporal information (e.g., the order in which visits occur in patients' EHRs data) and therefore limits its performance. Inspired by RETAIN, Zhang et al. [21] proposed an interpretable model to predict the risk of heart failure (IFM). IFM presents a position attention layer to capture the order of clinical events. However, IFM ignores the irregular time intervals between visits in patients' EHRs data. Both studies aim to calculate events' contributions to the clinical output risk, but ignore the importance of medical event patterns. When some clinical events occur simultaneously, they may lead to a sharp increase in risk even though each event alone does not cause high risk. In this study, we adopt a self-attention mechanism to capture clinically significant event patterns [23].

Methods
In this section, we give a detailed description of the proposed PAVE, which consists of four main parts. First, an embedding module represents medical events, variable values and their occurring times as vectors. Then, a self-attention module captures the pattern information between events. Next, a pattern attention module fuses all the pattern features, which are sent to a fully connected layer to predict the clinical outcomes. The framework of PAVE is shown in Fig. 1.

Problem definition and notation
The risk prediction task can be regarded as a binary classification problem: given a sequence of medical events, the framework aims to predict whether the patient will experience a certain medical event (e.g., a diagnosis code, mortality) in the future.
A patient's EHRs data consist of two main parts: static information and dynamic information. The static information comprises his/her demographics, such as gender and age. We represent each patient's demographics as one-hot vectors; patients' ages are divided into several age groups (e.g., 20-29, 30-39).
The dynamic information comprises his/her historical records, including diagnosis codes, medications, lab tests, and vital signs (patients in the ICU have vital sign data). Each diagnosis code is Boolean-value data, while the others are real-value data. One visit may contain several diagnosis codes and many collections of lab tests and vital sign data, and each collection usually has missing values for some lab test and vital sign items.
Given a patient, his/her data are denoted as (x, ŷ). The input data x include the demographics d and a sequence of n EHRs records, denoted as (e_1, t_1), (e_2, t_2), ..., (e_n, t_n). For each event e_i, its occurring time is denoted t_i. ŷ is the risk ground truth.
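For illustration, this input format can be sketched as a simple data structure (a hypothetical sketch; the names are illustrative and not taken from PAVE's released code):

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

@dataclass
class PatientRecord:
    # Static information d plus the sequence (e_1, t_1), ..., (e_n, t_n).
    demographics: dict                                        # e.g. {"gender": "F", "age_group": "60-69"}
    events: List[Tuple[str, Optional[float], float]] = field(default_factory=list)
    label: int = 0                                            # risk ground truth y_hat

    def add_event(self, name: str, value: Optional[float], time: float) -> None:
        # Only observed values are stored; missing values are simply absent,
        # so no imputation step is needed downstream.
        self.events.append((name, value, time))

p = PatientRecord({"gender": "F", "age_group": "60-69"}, label=1)
p.add_event("heart_rate", 95.0, 0.0)
p.add_event("diagnosis:sepsis", None, 1.5)   # Boolean-value event: no numeric value
```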

Embedding module
In this subsection, we present a new event embedding that takes into account variable values, the corresponding occurring time and patient demographics. As shown in Fig. 2, the embedding module takes the events (as well as their values), occurring times and demographics as inputs and adopts three embedding layers to project them into vectors.

Time embedding
The first embedding layer is the time embedding layer, which maps the occurring time t_i into a vector v_{t_i} ∈ R^k. Here t_i is the time interval between the occurring time of event e_i and that of the preceding event. The jth dimension of v_{t_i} is computed as a function of t_i, where t_m is the maximum of the time intervals and k denotes the dimension of v_{t_i}.
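As an illustration of such a time embedding, the following sketch uses a sinusoidal form in the spirit of positional encodings; the exact formula is not reproduced above, so the functional form here is an assumption:

```python
import math

def time_embedding(t_i: float, t_m: float, k: int) -> list:
    # Hypothetical sinusoidal time embedding: dimension j is a sinusoid of
    # the interval t_i normalized by the maximum interval t_m. This is a
    # sketch standing in for the paper's omitted equation.
    return [math.sin((t_i / t_m) * (j + 1) * math.pi / k) for j in range(k)]

v = time_embedding(2.0, 48.0, 8)   # k = 8 dimensions for a 2-hour interval
```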

Value embedding
The second embedding layer is the medical event value embedding layer, which maps each event e_i and its value v_i into a vector v_{e_i} ∈ R^k. Given an event, we first map it into a vector v_{e,e_i} via a fully connected layer. If the event is Boolean-valued (e.g., a diagnosis code), we directly use v_{e,e_i} as v_{e_i}. Otherwise, for real-value events (e.g., lab tests), given the value v_i of event e_i, the value embedding layer generates a vector v_{e,v_i} in the same way as the time embedding layer; its jth dimension is computed with v_min and v_max, the minimum and maximum values of the corresponding variable, and k, the dimension of v_{e,v_i}. Given v_{e,v_i} and v_{e,e_i}, a linear function combines them into v_{e_i} = W_v v_{e,v_i} + W_e v_{e,e_i} + b_e, where W_v, W_e ∈ R^{k×k} and b_e ∈ R^k are learnable parameters.
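The value embedding logic above can be sketched as follows (a minimal NumPy sketch under the stated definitions; the sinusoid helper stands in for the omitted formula and the parameter values are random placeholders, not trained weights):

```python
import numpy as np

rng = np.random.default_rng(0)
k = 16

# Hypothetical parameters: an event lookup table (v_e,e_i) plus the
# learnable W_v, W_e, b_e from the linear combination in the text.
event_table = rng.normal(size=(10, k))        # embeddings for 10 event types
W_v = rng.normal(size=(k, k))
W_e = rng.normal(size=(k, k))
b_e = np.zeros(k)

def sinusoid(x_norm: float, k: int) -> np.ndarray:
    # Sinusoidal embedding of a value normalized to [0, 1]; built "in the
    # same way as the time embedding layer" (exact formula assumed).
    j = np.arange(1, k + 1)
    return np.sin(x_norm * j * np.pi / k)

def value_embedding(event_id: int, value=None, v_min=0.0, v_max=1.0) -> np.ndarray:
    v_e = event_table[event_id]
    if value is None:                  # Boolean-value event (e.g., diagnosis code)
        return v_e
    v_v = sinusoid((value - v_min) / (v_max - v_min), k)
    return W_v @ v_v + W_e @ v_e + b_e   # linear combination from the text
```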

Demographic embedding
The third embedding layer is the demographic embedding layer, which embeds d into a matrix v_d ∈ R^{|d|×k}. A demographic attention mechanism is leveraged to attend to the demographic information and produces an attention result v_{d,a_i}. Given the embedding and attention results (i.e., v_{e_i}, v_{t_i} and v_{d,a_i}), a concatenation operation followed by a fully connected layer projects the ith event and the patient's demographics into an embedding vector q_i ∈ R^k:

q_i = W_{q,e} v_{e_i} + W_{q,t} v_{t_i} + W_{q,d} v_{d,a_i} + b_q,

where W_{q,e}, W_{q,t}, W_{q,d} ∈ R^{k×k} and b_q ∈ R^k are learnable parameters.

Fig. 1 Framework of PAVE. Given a patient, the event embedding module takes his/her demographics (i.e., age and gender) and medical events plus occurring times (e_1, t_1), (e_2, t_2), ..., (e_n, t_n) as inputs and generates a sequence of embedding vectors q = {q_1, q_2, ..., q_n} ∈ R^{n×k}; for each event i, v_{e_i} and v_{t_i} are concatenated and a fully connected layer is followed, whose output vector q_i is the embedding vector for the event. Three fully connected layers then map q to queries Q ∈ R^{n×k}, keys K ∈ R^{n×k} and values V ∈ R^{n×k}. Next, a self-attention module attends to meaningful patterns between medical events and outputs attention results P = {P_1, P_2, ..., P_n} ∈ R^{n×k}, which are sent to a pattern attention module to generate the attention result h ∈ R^k. Finally, a fully connected (FC) layer and a sigmoid layer output the clinical outcome risk.
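The projection into q_i can be sketched as follows (a NumPy sketch; the tanh nonlinearity is an assumption, as the text only specifies a fully connected layer, and the weights are random placeholders):

```python
import numpy as np

rng = np.random.default_rng(1)
k = 8

# Hypothetical learnable parameters W_q,e, W_q,t, W_q,d and b_q.
W_qe, W_qt, W_qd = (rng.normal(size=(k, k)) for _ in range(3))
b_q = np.zeros(k)

def event_embedding(v_e, v_t, v_d_a):
    # q_i combines the event/value embedding v_e, the time embedding v_t and
    # the demographic attention result v_d,a via one linear map per input,
    # which is equivalent to concatenation followed by a fully connected layer.
    return np.tanh(W_qe @ v_e + W_qt @ v_t + W_qd @ v_d_a + b_q)

q_i = event_embedding(rng.normal(size=k), rng.normal(size=k), rng.normal(size=k))
```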

Self attention module
Given a patient, his/her sequence of final event embeddings q = {q_1, q_2, ..., q_n} is input to the self-attention module to capture useful patterns between related events. Three fully connected layers, with learnable parameters W_Q, W_K, W_V ∈ R^{k×k} and b_Q, b_K, b_V ∈ R^k, map q into three matrices Q, K, V ∈ R^{n×k}, which are the queries, keys and values respectively. The self-attention output is a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. Specifically, for two events i and j, the dot product between query Q_i and key K_j represents their relevance β_ij, and a softmax layer generates the attention weight α_ij. Finally, a soft attention layer outputs the sum of the query event's value and the attention result over the key events' values as the pattern outcome P ∈ R^{n×k}, whose ith row is P_i = V_i + Σ_j α_ij V_j. The self-attention module captures two-event patterns; by stacking more self-attention layers, PAVE also has the potential to capture more complex medical patterns involving more events.
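A minimal NumPy sketch of this self-attention step (biases are omitted for brevity, and the scaling by sqrt(k) follows standard scaled dot-product practice and is an assumption):

```python
import numpy as np

def self_attention(q: np.ndarray, Wq, Wk, Wv):
    # Self-attention over the n event embeddings q (n x k): queries, keys and
    # values come from three linear maps; alpha_ij is a softmax over the
    # query-key relevance beta_ij; P_i = V_i + sum_j alpha_ij V_j.
    Q, K, V = q @ Wq, q @ Wk, q @ Wv
    beta = Q @ K.T / np.sqrt(q.shape[1])            # relevance beta_ij
    alpha = np.exp(beta - beta.max(axis=1, keepdims=True))
    alpha /= alpha.sum(axis=1, keepdims=True)       # softmax -> attention weights
    P = V + alpha @ V                               # query value + attended key values
    return P, alpha

rng = np.random.default_rng(2)
n, k = 5, 8
P, alpha = self_attention(rng.normal(size=(n, k)),
                          *[rng.normal(size=(k, k)) for _ in range(3)])
```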

Pattern attention module
There are various patterns in each patient's EHRs data, and only some of them are useful for the risk prediction goal. Given the pattern embeddings P ∈ R^{n×k}, a pattern attention mechanism is used to attend to the meaningful patterns.
Given a medical event pattern i for a patient, a fully connected layer with learnable parameters W_p ∈ R^k and b_p ∈ R computes its relevance θ_i to the risk prediction task. A softmax layer then computes the weights for the different patterns. Finally, soft attention combines the various patterns and produces a vector h, which contains the patient's clinical risk information.
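The pattern attention step can be sketched as follows (a NumPy sketch under the definitions above, with random placeholder parameters):

```python
import numpy as np

def pattern_attention(P: np.ndarray, w_p: np.ndarray, b_p: float):
    # Soft attention over the n pattern vectors P (n x k): an FC layer scores
    # each pattern's relevance theta_i, a softmax turns the scores into
    # weights, and the weighted sum gives the patient vector h.
    theta = P @ w_p + b_p
    weights = np.exp(theta - theta.max())
    weights /= weights.sum()
    h = weights @ P
    return h, weights

rng = np.random.default_rng(3)
h, w = pattern_attention(rng.normal(size=(6, 8)), rng.normal(size=8), 0.0)
```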

Objective function
A fully connected layer and a sigmoid layer are followed to predict the risk probability y = σ(W_h^T h + b_h), where W_h ∈ R^k and b_h ∈ R are learnable parameters. The cross-entropy between the ground truth ŷ and the predicted result y is used as the loss:

L(y, ŷ) = −(ŷ log(y) + (1 − ŷ) log(1 − y)).
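A sketch of the prediction head and loss (NumPy; the small eps is added for numerical stability and is not part of the paper's formula):

```python
import numpy as np

def predict_risk(h: np.ndarray, W_h: np.ndarray, b_h: float) -> float:
    # FC layer + sigmoid on the patient vector h, as described above.
    return float(1.0 / (1.0 + np.exp(-(h @ W_h + b_h))))

def bce_loss(y: float, y_hat: float, eps: float = 1e-12) -> float:
    # Cross-entropy between ground truth y_hat and predicted risk y.
    return float(-(y_hat * np.log(y + eps) + (1 - y_hat) * np.log(1 - y + eps)))

y = predict_risk(np.ones(4), np.zeros(4), 0.0)   # zero logits -> sigmoid(0) = 0.5
```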

Interpretability
PAVE's interpretability stems from its ability to compute each pattern's contribution to the output. Given a pattern (i, j), consisting of event i and event j, the contribution C_ij is calculated from the learned self-attention and pattern attention weights (Eq. (10)).
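One plausible way to realize this computation is to combine the two sets of attention weights; since Eq. (10) itself is not reproduced here, the exact combination below is an illustrative assumption:

```python
import numpy as np

def pattern_contributions(alpha: np.ndarray, pattern_weights: np.ndarray) -> np.ndarray:
    # Hypothetical reading of Eq. (10): the contribution C_ij of pattern
    # (i, j) combines the self-attention weight alpha_ij (how strongly event
    # i attends to event j) with the pattern attention weight of P_i.
    return pattern_weights[:, None] * alpha

alpha = np.array([[0.7, 0.3],
                  [0.4, 0.6]])      # self-attention weights (rows sum to 1)
pw = np.array([0.5, 0.5])           # pattern attention weights (sum to 1)
C = pattern_contributions(alpha, pw)
```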

Results and discussion
In order to evaluate the effectiveness of the proposed PAVE, we compare our model with state-of-the-art methods on two real-world clinical datasets: the publicly available MIMIC-III [14] and a proprietary EHRs database. The experiments are conducted on two different tasks: sepsis onset prediction and mortality prediction.

Datasets
The datasets for both the sepsis prediction and mortality prediction tasks are from the Intensive Care Unit (ICU).

Sepsis prediction
The first dataset is extracted from a real-world proprietary EHRs database. We use patients' demographic information and 27 kinds of time-series features, including vital signs and lab tests, to predict sepsis onset several hours in advance. Sepsis is one of the leading causes of mortality in hospitalized patients. We follow the sepsis-2 definition [24]: a sepsis-2 patient must meet at least two of the following four SIRS criteria:

• Body temperature > 38.0 °C or < 35.0 °C
• Respiratory rate > 20/min or PaCO2 < 32 mmHg
• Heart rate > 90/min
• WBC > 12k or < 4k, or bands > 10%
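The criteria above can be checked programmatically; a minimal sketch:

```python
def meets_sirs_criteria(temp_c, resp_rate, paco2_mmhg, heart_rate, wbc_k, band_pct):
    # Counts the four SIRS criteria listed above; sepsis-2 requires >= 2.
    criteria = [
        temp_c > 38.0 or temp_c < 35.0,
        resp_rate > 20 or paco2_mmhg < 32,
        heart_rate > 90,
        wbc_k > 12 or wbc_k < 4 or band_pct > 10,
    ]
    return sum(criteria) >= 2

# A febrile, tachycardic patient meets two criteria (temperature, heart rate):
flag = meets_sirs_criteria(38.6, 18, 40, 105, 8, 2)
```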

Mortality prediction
The second dataset is the publicly available MIMIC-III dataset [14]. We use patients' demographics and 8 vital signs to predict mortality in the coming hours. For each case patient (with sepsis-2 onset or mortality) in both datasets, 3 patients with the same age and gender are chosen as controls. For both cases and controls, our model predicts whether the patient suffers sepsis onset or mortality after a hold-off prediction window (e.g., 10, 8, 6 or 4 h). PAVE and the baselines take the patients' observed variables during the last 48 h as inputs (data in the hold-off window are excluded). The statistics of the selected datasets are listed in Table 1, and the selected variables are listed in Table 2.
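The observation/hold-off windowing described above can be sketched as follows (a hypothetical helper; times are hours from admission):

```python
def observation_window(events, onset_time, holdoff_h=10.0, window_h=48.0):
    # Keeps events in the 48-hour observation window that ends where the
    # hold-off prediction window begins, mirroring the setup described above.
    end = onset_time - holdoff_h
    start = end - window_h
    return [(name, value, t) for (name, value, t) in events if start <= t < end]

ev = [("hr", 90.0, 1.0), ("hr", 95.0, 40.0), ("hr", 120.0, 55.0)]
kept = observation_window(ev, onset_time=60.0)   # only the t = 40.0 reading survives
```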

Methods for comparison
To validate the performance of PAVE, we compare it with the following models, including three traditional machine learning methods and four deep learning methods. To demonstrate the effectiveness of the proposed time embedding and event embedding, we also implement three versions of PAVE.

Random forest (RF):
We represent each patient's demographics as a vector. For each variable, we extract the minimum and maximum values. The concatenated vectors of the patients are used to train the random forest model.
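A sketch of this baseline's feature construction, using scikit-learn as in the experiments (the zero sentinel for unobserved variables and the toy data are assumptions):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def min_max_features(patient_vars: dict, var_names: list) -> list:
    # Per-variable minimum and maximum, as described above; variables with
    # no observations get a sentinel value (an assumption, not from the paper).
    feats = []
    for name in var_names:
        vals = patient_vars.get(name, [])
        feats += [min(vals), max(vals)] if vals else [0.0, 0.0]
    return feats

var_names = ["heart_rate", "temperature"]
X = np.array([
    [1, 0] + min_max_features({"heart_rate": [80, 95], "temperature": [36.5]}, var_names),
    [0, 1] + min_max_features({"heart_rate": [110, 130]}, var_names),
])  # first two columns: one-hot gender (illustrative demographics)
y = np.array([0, 1])
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
```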

Logistic regression (LR):
We train the logistic regression model with the same vectors as the random forest. The logistic regression is trained with five different solvers (lbfgs, newton-cg, liblinear, sag and saga), and we choose the solver with the best performance on the validation set.
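The solver selection can be sketched with scikit-learn (synthetic data stand in for the real features):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
X_tr, y_tr = rng.normal(size=(200, 6)), rng.integers(0, 2, 200)
X_va, y_va = rng.normal(size=(100, 6)), rng.integers(0, 2, 100)

# Try the five solvers named above and keep the one with the best
# validation AUC.
best_auc, best_solver = -1.0, None
for solver in ["lbfgs", "newton-cg", "liblinear", "sag", "saga"]:
    lr = LogisticRegression(solver=solver, max_iter=1000).fit(X_tr, y_tr)
    auc = roc_auc_score(y_va, lr.predict_proba(X_va)[:, 1])
    if auc > best_auc:
        best_auc, best_solver = auc, solver
```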

Support vector machine (SVM):
We train the support vector machine model with the same vectors as the random forest. The support vector machine is trained with four different kernels, and we choose the kernel with the best performance on the validation set.

RETAIN:
RETAIN [8] is the first work that tries to interpret a model's disease risk prediction results with two attention modules. The attention modules generate weights for every medical event; these weights help analyze different events' contributions to the output risk.

IFM:
IFM [21] is an interpretable heart failure risk prediction model, which is also based on an attention mechanism and leverages the attention weights to interpret the outputs. In this work, we modify IFM to predict sepsis onset and mortality.

Implementation details
We implement all the baselines and the proposed PAVE models with PyTorch 0.4.1 and scikit-learn. For the traditional machine learning approaches (i.e., LR, RF and SVM), a grid search is adopted to find the best parameter settings.

Results of risk prediction
As shown in Table 3, the proposed model PAVE outperforms all the baselines, which demonstrates the effectiveness of our model.
The deep learning approaches outperform the traditional machine-learning approaches, which take flat vectors rather than sequences as inputs. The traditional approaches' inputs lose the temporal information of EHR data, which is very important for risk prediction, while deep learning models are good at modeling temporal data; thus, the deep learning baselines achieve better performance. Among the deep learning baselines, the attention-based models (i.e., RETAIN and IFM) perform better on the mortality prediction task, while LSTM and GRU perform better on the sepsis onset prediction task. We speculate that mortality is easier to predict from several vital sign features in recent hours, such as heart rate and respiratory rate; attention-based models excel at capturing important events and thus achieve better performance there. Sepsis is a complex disease that is more difficult to predict than mortality, and predicting sepsis onset depends on changes in patients' health states over a relatively long period. LSTM and GRU are better at modeling such long-term state changes, while RETAIN and IFM lose some temporal information through their attention mechanisms. In the clinical domain, models' interpretability can be more important than their performance, so the interpretable risk prediction models (i.e., PAVE, RETAIN and IFM) are more suitable for real-world clinical applications. Compared with RETAIN and IFM, PAVE leverages an attention mechanism to focus on important events and incorporates time information with time embedding, so it outperforms RETAIN and IFM by 1.5 and 3 percentage points on the sepsis and mortality prediction tasks respectively. Among the three versions of the proposed model, PAVE−T performs worse than PAVE, which shows that the time embedding allows PAVE to capture more information about time intervals.
PAVE also outperforms PAVE−E, which takes imputed values as inputs instead of value embeddings. The imputation strategy may introduce bias and thus harm the final risk prediction tasks.

Medical event pattern analysis
PAVE is able to analyze the patterns' contributions to the prediction. We compute each pattern's contribution to the risk of mortality for each patient according to Eq. (10). For each variable, the values are divided into five ranges. By comparing each value to the variable's normal range, it is first mapped into one of three ranges (low, normal or high). The high range is then split into two parts (high and very high) by comparing the value to the median of all observed high values, and the low range is split in the same way (low and very low). We display the top 10 patterns with the highest average contribution rates to mortality among all the case patients (10-h mortality prediction) in Table 4. The patterns are verified by clinicians to be high-risk signals for mortality, which demonstrates that PAVE can find useful patterns in the prediction tasks.
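The five-range discretization can be sketched as follows (a minimal sketch of the procedure described above):

```python
import statistics

def discretize(value, low, high, high_values, low_values):
    # Maps a value to one of five ranges: first low / normal / high against
    # the normal range [low, high], then split the high (low) range at the
    # median of all observed high (low) values.
    if low <= value <= high:
        return "normal"
    if value > high:
        return "very high" if high_values and value > statistics.median(high_values) else "high"
    return "very low" if low_values and value < statistics.median(low_values) else "low"

highs = [39.0, 39.5, 40.5, 41.0]   # observed above-normal temperatures (median 40.0)
label = discretize(40.8, 36.0, 38.0, highs, [])
```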
We conducted the experiments many times and found that some patterns always have relatively high weights. For example, the weight of the pattern (very high temperature and very low respiratory rate) is always much higher than those of other random patterns, which is consistent with the clinical knowledge that patients with simultaneously very high temperature and very low respiratory rate have a high risk of mortality.

Case study
We applied PAVE to predict the mortality risk of a patient from the test set who suffered mortality after 10 h. We display the observed variables during the last 24 h of the observation window in Fig. 3. RETAIN is also used to predict the mortality risk for comparison. Both PAVE and RETAIN accurately predict the patient's mortality after 10 h. In this case study, we mainly focus on the interpretability of the detected medical events or patterns with high contribution risks. The black stars in Fig. 3 represent observed abnormal values with high instance-wise contribution risks generated by RETAIN, while the colored squares are medical event patterns detected by PAVE. In this case, PAVE found three patterns with high contribution risks: (1) high SysBP and high temperature (orange squares); (2) high heart rate and high temperature (red squares); (3) stable high heart rate and high respiratory rate (blue squares). Events sharing the same color belong to the same detected pattern. Note that only the patterns with relatively high contribution risks are shown in the figure. The sizes of the black stars and colored squares denote the corresponding contribution risks. Both models successfully detect crucial medical events related to high mortality risk, such as high heart rate and high temperature. PAVE focuses much more on the variables observed during the last 10 h of the observation window (e.g., the stable high heart rate and high respiratory rate in blue squares), while RETAIN attends to many earlier events but ignores the later high heart rate and high respiratory rate in the last three collections. This suggests that PAVE learns that both the later medical events and the abnormal values are more useful for accurate mortality prediction, while RETAIN focuses only on abnormal values.
Moreover, when some crucial patterns appear (e.g., high heart rate and high respiratory rate in blue squares), PAVE assigns more attention weight to the patterns than RETAIN (the colored squares are bigger than the corresponding stars), which demonstrates that PAVE is effective at mining related and important patterns and pays more attention to the meaningful ones.

Conclusion
In this work, we proposed PAVE, an interpretable pattern attention model with value embedding to predict disease risk. PAVE takes into account real-value medical events (e.g., lab tests and vital signs) by embedding the values into vectors, and therefore does not need to impute the missing values. Moreover, PAVE is based on attention mechanisms and the attention weights can be used to interpret the model's clinical outputs.
To the best of our knowledge, PAVE is the first interpretable deep learning model that provides medical pattern-wise interpretability rather than only instance-wise interpretability. Event patterns may cause a much higher risk than each single event in the pattern. We conducted extensive experiments on two real-world datasets, and PAVE achieved better performance than state-of-the-art models. Moreover, the experimental results show that PAVE is able to detect many medical event patterns with high contribution rates to mortality and sepsis onset, which paves the way for interpretable clinical risk predictions.

Fig. 3 Case study. The figure shows the observed variables of a case patient during the last 24 h observation window (the hold-off window is 10 h). The black stars represent observed abnormal values with instance-wise contribution risks generated by RETAIN, while the colored squares are medical event patterns detected by PAVE. Events sharing the same color belong to the same detected pattern. The sizes of the black stars and colored squares denote the corresponding contribution risks. Note that only the events or patterns with relatively higher contribution risks are marked in the figure. PAVE found three high-risk patterns for the patient: (1) high SysBP and high temperature (orange squares); (2) high heart rate and high temperature (red squares); (3) stable high heart rate and high respiratory rate (blue squares)