
An interpretable risk prediction model for healthcare with pattern attention

Abstract

Background

The availability of massive amounts of data makes clinical predictive tasks possible. Deep learning methods have achieved promising performance on these tasks. However, most existing methods suffer from three limitations: (1) Real-value events have many missing values. Many methods impute the missing values and then train their models on the imputed values, which may introduce imputation bias, so the models’ performance is highly dependent on the imputation accuracy. (2) Many existing studies take only Boolean-value medical events (e.g., diagnosis codes) as inputs and ignore real-value medical events (e.g., lab tests and vital signs), which are more important for acute disease (e.g., sepsis) and mortality prediction. (3) Existing interpretable models can show which medical events contribute to the output, but cannot give the contributions of patterns among medical events.

Methods

In this study, we propose a novel interpretable Pattern Attention model with Value Embedding (PAVE) to predict the risks of certain diseases. PAVE takes the embeddings of medical events, their values and their occurrence times as inputs, and leverages a self-attention mechanism to attend to meaningful patterns among medical events for risk prediction. Because only the observed values are embedded into vectors, PAVE does not need to impute missing values and thus avoids imputation bias. Moreover, the self-attention mechanism makes the model interpretable: PAVE can output which patterns lead to high risks.

Results

We conduct sepsis onset prediction and mortality prediction experiments on the publicly available MIMIC-III dataset and on our proprietary EHR dataset. The experimental results show that PAVE outperforms existing models. Moreover, by analyzing the self-attention weights, our model outputs meaningful medical event patterns related to mortality.

Conclusions

PAVE learns effective medical event representations by incorporating values and occurrence times, which improves risk prediction performance. Moreover, the presented self-attention mechanism not only captures patients’ health state information but also outputs the contributions of various medical event patterns, which paves the way for interpretable clinical risk predictions.

Availability

The code for this paper is available at: https://github.com/yinchangchang/PAVE.

Background

With the growth of Electronic Health Records (EHRs) in both volume and diversity over the last decades, it has become possible to apply clinical predictive models to improve the quality of clinical care. EHRs are temporal sequence data consisting of diagnosis codes, medications, lab results, and vital signs. The patient health information contained in massive EHRs is extremely useful for many tasks in the medical domain, such as risk prediction [1, 2], patient subtyping [3, 4], treatment effect estimation [5, 6], and patient similarity analysis [7]. In this paper, we focus on clinical risk prediction. Most state-of-the-art clinical risk prediction models are based on deep learning and are trained end to end. Recurrent Neural Networks (RNNs), popular deep learning models for sequences, have recently achieved good performance in clinical risk prediction tasks [8,9,10]. However, several challenges remain. (1) Most existing methods [11, 12] represent medical events as embedding vectors, which discards the real-value information of medical events (e.g., lab tests and vital signs). (2) Lab tests are diagnosis-driven, so EHRs contain many missing lab test values. Many methods [13] impute the missing values and then train their models on the imputed values, so the models’ performance is highly dependent on the imputation accuracy. (3) Existing interpretable models only provide instance-wise variable importance (i.e., each medical event’s contribution to the disease risk) rather than pattern-wise importance. However, some clinical events may sharply increase the risk when they occur simultaneously, even though each event alone does not cause high risk.

In this study, we propose a new interpretable Pattern Attention model with Value Embedding (PAVE), which is based entirely on attention mechanisms. For each patient, medical events, their values (e.g., lab test and vital sign values) and their occurrence times are represented as embedding vectors and projected into a medical semantic space. A self-attention layer then captures meaningful patterns among medical events. A pattern attention module attends to the event patterns and produces an attention vector for each patient. Finally, a fully connected layer predicts the patient’s risk of future clinical outcomes. By analyzing the self-attention weights and pattern attention weights, our model can compute the contribution rates of various medical event patterns, thus paving the way for interpretable clinical risk predictions.

To demonstrate the effectiveness of PAVE, we compare it against both traditional machine learning methods (e.g., logistic regression, random forest) and recent deep learning methods (e.g., RETAIN) on sepsis and mortality risk prediction tasks. We conducted experiments on both the publicly available MIMIC-III dataset [14] and our proprietary EHR data. The experimental results show that PAVE outperforms all the baselines on both datasets and in various settings. Moreover, once PAVE is trained, it can identify the EHR event patterns with high contribution rates to high mortality risk. The main contributions of this work are as follows:

  • We propose a novel interpretable risk prediction model PAVE, which is based on a self-attention mechanism and achieves better performance than the baselines.

  • The presented self-attention mechanism can automatically capture meaningful patterns and is helpful to find the patterns related to high risks. To the best of our knowledge, this work is the first attempt to identify the contributions of patterns.

  • We propose a new value embedding that maps values into vectors, so missing values do not need to be imputed.

  • Our medical event embedding module takes the occurrence times of medical events into account.

Fig. 1

Framework of PAVE. Given a patient, the event embedding module takes his/her demographics (i.e., age and gender) and medical events with their occurrence times \((e_1, t_1), (e_2, t_2) ,\ldots , (e_n, t_n)\) as inputs and generates a sequence of embedding vectors \(q=\{q_1, q_2,\ldots , q_n\} \in R^{n \times k}\). Three fully connected layers then map q to queries \(Q \in R^{n \times k}\), keys \(K \in R^{n \times k}\) and values \(V \in R^{n \times k}\). Next, a self-attention module attends to meaningful patterns between medical events and outputs attention results \(P= \{P_1, P_2,\ldots , P_n\} \in R^{n \times k}\), which are sent to a pattern attention module that generates the attention result \(h \in R^k\). Finally, a fully connected (FC) layer and a sigmoid layer output the clinical outcome risk.

Related work

Owing to their promising performance in clinical risk prediction, deep learning methods have attracted significant interest from healthcare researchers. In this section, we review existing work related to deep learning models, covering risk prediction, attention mechanisms, and clinical model interpretability.

Risk prediction for healthcare

Extensive research has shown the potential of early disease risk prediction from Electronic Health Records (EHRs), which has attracted substantial attention [13, 15,16,17,18]. In this section, we mainly focus on Recurrent Neural Network (RNN) based models. RNNs can be used for patient subtyping [3], phenotyping [19], similarity measurement [7], and missing value imputation [13, 20], which are closely related to risk prediction tasks. Some RNN based approaches do not consider the relationships between subsequent visits. To address this issue, Dipole [11] adopts attention mechanisms to capture the relations among visits and thereby significantly improves prediction accuracy.

When preprocessing EHR data, most existing models ignore the time intervals between neighboring medical events. However, time intervals are common and important in many healthcare applications. A time-aware patient subtyping model [3] was therefore proposed to take the time intervals in patients’ EHR data into account, and that work showed that doing so can significantly improve the model’s performance.

Attention mechanism

EHR data contain many kinds of medical events (e.g., diagnoses and medications), including redundant and useless information. Only the events related to specific diseases are crucial for predicting risk. Attention mechanisms have therefore been introduced to automatically attend to the useful events [8, 11, 21].

Attention mechanisms have proven helpful in natural language processing. Vaswani et al. proposed the Transformer [22] for machine translation. The Transformer uses self-attention to capture the relations between the words of an input sentence. Self-attention is highly parallelizable and easy to train. This work adopts a self-attention mechanism for clinical risk prediction and simultaneously aims to find clinically significant patterns related to sepsis and mortality risk.

Interpretability

In the clinical domain, a model’s interpretability can be more important than its performance. Black-box approaches, especially deep learning methods, are not trusted by doctors and are therefore not applied in real clinical settings. This has motivated a great deal of work on the interpretability of risk prediction models. RETAIN [8] is the first work that can interpret why the model makes particular predictions. It uses two attention modules (i.e., visit-level and code-level attention) to detect influential visits and significant medical codes. The attention weights of events indicate their importance to the clinical outputs.

RETAIN then feeds the weighted average of each patient’s event embeddings into a fully connected layer to predict the risk, which loses temporal information (e.g., the order in which visits occur in the patient’s EHR data). As a result, RETAIN achieves limited performance. Inspired by RETAIN, Zhang et al. [21] proposed an interpretable model to predict the risk of heart failure (IFM). IFM presents a position attention layer to capture the order of clinical events. However, IFM ignores the irregular time intervals between visits in patients’ EHR data. Both studies aim to calculate events’ contributions to the clinical output risk but ignore the importance of medical event patterns. Some clinical events may sharply increase the risk when they occur simultaneously, even though each event alone does not cause high risk. In this study, we adopt a self-attention mechanism to capture clinically significant event patterns [23].

Methods

In this section, we describe the proposed PAVE in detail. It consists of four main parts. First, an embedding module represents medical events, variable values and occurrence times as vectors. Then, a self-attention module captures the pattern information between events. Next, a pattern attention module fuses all the pattern features, which are sent to a fully connected layer to predict the clinical outcomes. The framework of PAVE is shown in Fig. 1.

Problem definition and notation

The risk prediction task can be regarded as a binary classification problem. Given a sequence of medical events, the framework aims to predict whether the patient will experience a certain medical event (e.g., a diagnosis code or mortality) in the future.

A patient’s EHR data consist of two main parts: static information and dynamic information. Static information consists of the patient’s demographics, such as gender and age. We represent each patient’s demographics as one-hot vectors, with ages divided into several age groups (e.g., 20–29, 30–39).
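As a minimal sketch of this static input, assuming ten decade-wide age groups (0-9, ..., 90+) and a two-value gender vocabulary, both of which are our own illustrative choices rather than the groups used in the paper, each demographic item can be one-hot encoded separately. The function names are hypothetical.

```python
from typing import List, Sequence


def one_hot(index: int, size: int) -> List[int]:
    """Return a list of length `size` with a 1 at `index` and 0 elsewhere."""
    vec = [0] * size
    vec[index] = 1
    return vec


def encode_demographics(age: int, gender: str,
                        genders: Sequence[str] = ("F", "M"),
                        n_age_groups: int = 10) -> List[List[int]]:
    """One-hot encode age (decade groups, last group open-ended) and gender.

    Returns one vector per demographic item (here |d| = 2 items).
    """
    if age < 0:
        raise ValueError("age must be non-negative")
    age_group = min(age // 10, n_age_groups - 1)  # e.g. 34 -> group 3 (30-39)
    return [one_hot(age_group, n_age_groups),
            one_hot(list(genders).index(gender), len(genders))]
```

Keeping one vector per item, rather than one concatenated vector, matches the |d| × k demographic matrix that the demographic attention of Eq. (4) operates on.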

The dynamic information is the patient’s historical records, including diagnosis codes, medications, lab tests, and vital signs (vital sign data are available for ICU patients). Diagnosis codes are Boolean-value data, while the others are real-value data. A single visit may contain several diagnosis codes and many collections of lab test and vital sign data, and some lab test and vital sign items are usually missing in each collection.

Given a patient, his/her data are denoted as (x, \({\hat{y}}\)). The input data x include the demographics d and a sequence of n EHR records, denoted as \((e_1, t_1), (e_2, t_2),\ldots , (e_n, t_n)\). For each event \(e_i\), its occurrence time is denoted by \(t_i\). \({\hat{y}}\) is the ground-truth risk label.

Fig. 2

Embedding module. For each event i, we embed its occurrence time into vector \(v_i^t\) and its event value into vector \(v_i^e\). Then \(v_i^e\) is used to attend to the patient’s demographic information. The attention result \(v_i^{d,a}\), \(v_i^e\) and \(v_i^t\) are concatenated and followed by a fully connected layer. The output vector \(q_i\) is the embedding vector for event i

Embedding module

In this subsection, we present a new event embedding that takes into account variable values, the corresponding occurrence times and patient demographics. As shown in Fig. 2, the embedding module takes events (with their values), occurrence times and demographics as inputs and adopts three embedding layers to project them into vectors.

Time embedding

The first embedding layer is the time embedding layer, which maps the occurrence time \(t_i\) into a vector \(v_i^t \in R^k\). Here \(t_i\) is the interval between the occurrence time of event \(e_i\) and the time of the last event. The \(j{\mathrm{th}}\) dimension of \(v_i^t\) is computed as:

$$\begin{aligned} v_{i, j}^t = {\left\{ \begin{array}{ll} \sin \left( \frac{t_i * j}{t_m * k}\right) , &\quad {\text {if j is even}} \\ \cos \left( \frac{t_i * j}{t_m * k}\right) , &\quad {\text {if j is odd}}, \end{array}\right. } \end{aligned}$$
(1)

where \(t_m\) is the maximum of time intervals, k denotes the dimension of \(v_{i}^t\).
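Eq. (1) can be sketched in plain Python as follows. We assume 0-based dimension indices j = 0, ..., k - 1 (the text does not say whether j starts at 0 or 1) and that t_i and t_m use the same unit, e.g. hours; the function name `time_embedding` is hypothetical.

```python
import math
from typing import List


def time_embedding(t_i: float, t_m: float, k: int) -> List[float]:
    """Sinusoidal embedding of the time interval t_i, following Eq. (1).

    Dimension j (0-based) uses sin for even j and cos for odd j,
    with argument (t_i * j) / (t_m * k).
    """
    if t_m <= 0:
        raise ValueError("t_m (maximum time interval) must be positive")
    vec = []
    for j in range(k):
        arg = (t_i * j) / (t_m * k)
        vec.append(math.sin(arg) if j % 2 == 0 else math.cos(arg))
    return vec
```

Because 0 <= t_i <= t_m and j < k, every argument lies in [0, 1), which is below pi/2, so each dimension is a monotonic function of the interval instead of wrapping around as positional encodings over long sequences do. Under the 0-based assumption, dimension 0 is always sin(0) = 0.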

Value embedding

The second is the medical event value embedding layer, which maps each event \(e_i\) and its value \(v_i\) into a vector \(v_i^e \in R^k\). Given an event and its value, we map the event into a vector \(v_i^{e,e}\) via a fully connected layer. If the event is Boolean-valued (e.g., a diagnosis code), we directly use \(v_i^{e,e}\) as \(v_i^e\). Otherwise, for real-valued events (e.g., lab tests), given the value \(v_i\) of event \(e_i\), the value embedding layer generates a vector \(v_i^{e,v}\) in the same way as the time embedding layer. The \(j{\mathrm{th}}\) dimension of \(v_i^{e,v}\) is computed as:

$$\begin{aligned} v_{i, j}^{e,v} = {\left\{ \begin{array}{ll} \sin \left( \frac{(v_i - v_{min}) * j}{(v_{max} - v_{min}) * k}\right) , &\quad {\text {if j is even}} \\ \cos \left( \frac{(v_i - v_{min}) * j}{(v_{max} - v_{min}) * k}\right) , &\quad {\text {if j is odd,}} \end{array}\right. } \end{aligned}$$
(2)

where \(v_{min}\) and \(v_{max}\) are the minimum and maximum values of the corresponding variable, k denotes the dimension of \(v_i^{e,v}\). Given \(v_i^{e,v}\) and \(v_i^{e,e}\), a linear function is used to combine them to \(v_i^e\):

$$\begin{aligned} v_i^e = v_i^{e,v} W_v + v_i^{e,e} W_e + b_e, \end{aligned}$$
(3)

where \(W_v, W_e \in R^{k \times k}\) and \(b_e \in R^k\) are learnable parameters.
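The sketch below follows Eqs. (2) and (3) for a single event. It assumes 0-based dimension indices as in the time embedding, represents \(W_v\) and \(W_e\) as k × k nested lists, and uses `None` to mark a Boolean event that carries no real value; all function names are hypothetical. What it illustrates is that only observed (event, value) pairs are ever embedded, so a missing lab value simply produces no event rather than an imputed number.

```python
import math
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def sinusoid(x: float, k: int) -> List[float]:
    """sin for even j, cos for odd j, with argument x * j / k (0-based j)."""
    return [math.sin(x * j / k) if j % 2 == 0 else math.cos(x * j / k)
            for j in range(k)]


def row_times_matrix(x: Sequence[float], W: Matrix) -> List[float]:
    """Row vector x (length k) times a k x k matrix W."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]


def value_embedding(v_ee: Sequence[float], value: Optional[float],
                    v_min: float, v_max: float,
                    W_v: Matrix, W_e: Matrix, b_e: Sequence[float]) -> List[float]:
    """Eqs. (2)-(3). `v_ee` is the event vector v_i^{e,e}; `value` is None for Boolean events."""
    if value is None:  # Boolean event: v_i^e = v_i^{e,e}
        return list(v_ee)
    if v_max <= v_min:
        raise ValueError("v_max must exceed v_min")
    k = len(v_ee)
    scaled = (value - v_min) / (v_max - v_min)  # min-max normalisation inside Eq. (2)
    v_ev = sinusoid(scaled, k)                  # v_i^{e,v}
    a = row_times_matrix(v_ev, W_v)             # v_i^{e,v} W_v
    b = row_times_matrix(v_ee, W_e)             # v_i^{e,e} W_e
    return [a[c] + b[c] + b_e[c] for c in range(k)]
```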

Demographic embedding

The third embedding layer is the demographic embedding layer, which embeds d into a matrix \(v^d \in R^{|d| \times k}\). A demographic attention mechanism is then used to attend to the demographic information:

$$\begin{aligned} v_i^{d,a}& = \sum _{j=1}^{|d|} v_j^d * \alpha _j^{d,i} \nonumber \\ \alpha _j^{d,i}& = \frac{\exp (\beta _j^{d,i}) }{ \sum _{u=1}^{|d|} \exp (\beta _u^{d,i})}\nonumber \\ \beta _j^{d,i}& = v_j^d W_{v,d} + v_i^e W_{v,e}, \end{aligned}$$
(4)

where \(W_{v,d}, W_{v, e} \in R^k\) are learnable parameters, \(v_j^d \in R^k\) denotes the \(j{\mathrm{th}}\) row of \(v^d\), and \(v_i^{d,a} \in R^k\) is the demographic attention result.
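A literal sketch of Eq. (4) for one event i is given below, with hypothetical names and nested lists in place of tensors. One property worth noting, which follows from the shift invariance of the softmax: as Eq. (4) is written, the term \(v_i^e W_{v,e}\) is the same for every j, so it shifts all \(\beta_j^{d,i}\) equally and cancels, and the weights do not depend on the event. If event-dependent weights are intended, a nonlinearity between the two terms (as in additive attention) would be needed; we do not know how the released code handles this.

```python
import math
from typing import List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax (subtracts the max before exponentiating)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def demographic_attention(v_d: List[List[float]], v_e_i: Sequence[float],
                          w_vd: Sequence[float], w_ve: Sequence[float]
                          ) -> Tuple[List[float], List[float]]:
    """Eq. (4): weights alpha_j over the |d| rows of v_d and their weighted sum v_i^{d,a}."""
    betas = [dot(v_j, w_vd) + dot(v_e_i, w_ve) for v_j in v_d]
    alphas = softmax(betas)
    k = len(v_d[0])
    v_da = [sum(alphas[j] * v_d[j][c] for j in range(len(v_d))) for c in range(k)]
    return v_da, alphas
```

Calling `demographic_attention` twice with different `v_e_i` but the same `v_d`, `w_vd` and `w_ve` returns the same `alphas`, which is the cancellation described above.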

Given the embedding and attention results (i.e., \(v_i^e\), \(v_i^t\) and \(v_i^{d,a}\)), a concatenation operation followed by a fully connected layer projects the \(i{\mathrm{th}}\) event and the patient’s demographics into an embedding vector \(q_i \in R^k\):

$$\begin{aligned} q_i = v_i^{e} W_{q,e} + v_i^{t} W_{q,t} + v_i^{d,a} W_{q,d} + b_q, \end{aligned}$$
(5)

where \(W_{q,e}, W_{q,t}, W_{q,d} \in R^{k \times k}\) and \(b_q \in R^k\) are learnable parameters.
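The equivalence between the concatenation-plus-FC description and the sum in Eq. (5) follows from splitting the FC weight matrix into three k × k blocks, written here with the symbols of Eq. (5):

$$\begin{aligned} \left[\, v_i^{e} \;\; v_i^{t} \;\; v_i^{d,a} \,\right] \begin{bmatrix} W_{q,e} \\ W_{q,t} \\ W_{q,d} \end{bmatrix} + b_q = v_i^{e} W_{q,e} + v_i^{t} W_{q,t} + v_i^{d,a} W_{q,d} + b_q = q_i, \end{aligned}$$

where the concatenated row vector lies in \(R^{1 \times 3k}\) and the stacked weight matrix lies in \(R^{3k \times k}\).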

Self attention module

Given a patient, his/her sequence of final event embeddings \(q = \{q_1, q_2,\ldots , q_n\}\) is fed into the self-attention module to capture useful patterns between related events. Three fully connected layers map q into three matrices \(Q, K, V \in R^{n \times k}\), which are the queries, keys and values, respectively. The self-attention output is a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. Specifically, we compute the dot product of each query \(Q_i\) with each key \(K_j\) and calculate the attention weight \(\alpha _{ij}\) with a softmax function. The pattern attention outcome \(P \in R^{n \times k}\) is then the sum of the query event’s value and the attention-weighted sum of the key events’ values. The \(i{\mathrm{th}}\) row of P is computed as follows:

$$\begin{aligned} P_i& = V_i + \sum _j \alpha _{ij} V_j \nonumber \\ \alpha _{ij}& = \frac{\exp (\beta _{ij})}{ \sum _l \exp (\beta _{il}) } \nonumber \\ \beta _{ij}& = Q_i K_j^T \nonumber \\ Q_i& = q_i W_Q + b_Q \nonumber \\ K_i& = q_i W_K + b_K \nonumber \\ V_i& = q_i W_V + b_V, \end{aligned}$$
(6)

where \(W_Q, W_K, W_V \in R^{k \times k}\) and \(b_Q, b_K, b_V \in R^{k}\) are learnable parameters. Given two events i and j, the product between query \(Q_i\) and key \(K_j\) represents their relevance \(\beta _{ij}\). A softmax layer then generates the attention weight \(\alpha _{ij}\). Finally, a soft attention layer produces the pattern vector \(P_i\). A single self-attention module captures two-event patterns; by stacking more self-attention layers, PAVE can potentially capture more complex medical patterns involving more events.
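A dependency-free sketch of Eq. (6) follows; the names are hypothetical and nested lists stand in for tensors. Two details are taken directly from the equation: the scores \(Q_i K_j^T\) are not divided by \(\sqrt{k}\) (the Transformer's scaled dot-product attention does divide by the square root of the key dimension), and the query event's own value \(V_i\) is added back to the weighted sum. The sum over j runs over all events, including j = i.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def linear(X: Matrix, W: Matrix, b: Sequence[float]) -> Matrix:
    """Row-wise affine map X W + b (X is n x k, W is k x k, b has length k)."""
    return [[sum(x[r] * W[r][c] for r in range(len(x))) + b[c] for c in range(len(b))]
            for x in X]


def softmax(scores: Sequence[float]) -> List[float]:
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pave_self_attention(q: Matrix, W_Q: Matrix, b_Q: Sequence[float],
                        W_K: Matrix, b_K: Sequence[float],
                        W_V: Matrix, b_V: Sequence[float]) -> Tuple[Matrix, Matrix]:
    """Eq. (6): returns the pattern matrix P (n x k) and the weights alpha (n x n)."""
    Q, K, V = linear(q, W_Q, b_Q), linear(q, W_K, b_K), linear(q, W_V, b_V)
    n, k = len(q), len(b_V)
    # beta_ij = Q_i K_j^T over all j (including j = i), without 1/sqrt(k) scaling
    beta = [[sum(Q[i][c] * K[j][c] for c in range(k)) for j in range(n)] for i in range(n)]
    alpha = [softmax(row) for row in beta]
    # P_i = V_i + sum_j alpha_ij V_j (the query event's own value is added back)
    P = [[V[i][c] + sum(alpha[i][j] * V[j][c] for j in range(n)) for c in range(k)]
         for i in range(n)]
    return P, alpha
```

Each row of `alpha` sums to 1, which is what later allows Eq. (10) to split a pattern attention weight among event pairs.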

Pattern attention module

Each patient’s EHR data contain various patterns, but only some are useful for the risk prediction goal. Given the pattern embeddings \(P \in R^{n \times k}\), a pattern attention mechanism is used to attend to the meaningful patterns:

$$\begin{aligned} h& = \sum _i \gamma _i P_i\nonumber \\ \gamma _i& = \frac{exp (\theta _i)}{ \sum _j exp(\theta _j)} \nonumber \\ \theta _i& = P_i W_p + b_p, \end{aligned}$$
(7)

where \(W_p \in R^{k}\) and \(b_p \in R\) are learnable parameters. Given a medical event pattern i for a patient, a fully connected layer computes its relevance \(\theta _i\) to the risk prediction task. A softmax layer then computes the weights of the different patterns. Finally, a soft attention combines the patterns and produces a vector h, which contains the patient’s clinical risk information.
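Eq. (7) reduces to a single scoring vector followed by a softmax-weighted average. A minimal sketch with hypothetical names, taking the rows \(P_i\) produced by the self-attention module as nested lists:

```python
import math
from typing import List, Sequence, Tuple


def softmax(scores: Sequence[float]) -> List[float]:
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pattern_attention(P: List[List[float]], w_p: Sequence[float], b_p: float
                      ) -> Tuple[List[float], List[float]]:
    """Eq. (7): theta_i = P_i w_p + b_p, gamma = softmax(theta), h = sum_i gamma_i P_i."""
    theta = [sum(p * w for p, w in zip(P_i, w_p)) + b_p for P_i in P]
    gamma = softmax(theta)
    k = len(P[0])
    h = [sum(gamma[i] * P[i][c] for i in range(len(P))) for c in range(k)]
    return h, gamma
```

The returned `gamma` is kept alongside `h` because the same weights are reused in the contribution score of Eq. (10).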

Objective function

A fully connected layer followed by a sigmoid layer predicts the risk probability:

$$\begin{aligned} y = sigmoid(h W_h + b_h), \end{aligned}$$
(8)

where \(W_h \in R^k\) and \(b_h \in R\) are learnable parameters. The cross-entropy between ground truth \({\hat{y}}\) and predicted result y is used to compute loss:

$$\begin{aligned} L(y, {\hat{y}}) = -({\hat{y}} log(y) + (1 - {\hat{y}}) log(1 - y)). \end{aligned}$$
(9)
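Eqs. (8) and (9) for a single patient can be sketched as follows (hypothetical names). The branch on the sign of z is a standard numerically stable way to evaluate the sigmoid, and the clipping by `eps` is our addition, not part of Eq. (9), to keep log(0) from occurring. In PyTorch the two steps would more typically be fused with `torch.nn.BCEWithLogitsLoss`.

```python
import math
from typing import Sequence


def predict_risk(h: Sequence[float], w_h: Sequence[float], b_h: float) -> float:
    """Eq. (8): y = sigmoid(h W_h + b_h)."""
    z = sum(x * w for x, w in zip(h, w_h)) + b_h
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)  # avoids overflow of exp(-z) for large negative z
    return e / (1.0 + e)


def bce_loss(y: float, y_hat: float, eps: float = 1e-12) -> float:
    """Eq. (9): binary cross-entropy between prediction y and ground truth y_hat."""
    y = min(max(y, eps), 1.0 - eps)  # clipping is an addition, not in Eq. (9)
    return -(y_hat * math.log(y) + (1.0 - y_hat) * math.log(1.0 - y))
```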

Interpretability

PAVE is interpretable in that it can compute each pattern’s contribution to the output. Given a pattern (i, j) consisting of event i and event j, its contribution \(C_{ij}\) is calculated as follows:

$$\begin{aligned} C_{ij} = \gamma _i \alpha _{ij}. \end{aligned}$$
(10)
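Given \(\gamma\) from Eq. (7) and \(\alpha\) from Eq. (6), Eq. (10) is an outer weighting of the two attention distributions. Because \(\gamma\) sums to 1 and every row of \(\alpha\) sums to 1, the \(C_{ij}\) of one patient also sum to 1, so they can be read as shares of the total attention mass. A sketch with hypothetical names:

```python
from typing import Dict, List, Sequence, Tuple


def pattern_contributions(gamma: Sequence[float], alpha: Sequence[Sequence[float]]
                          ) -> Dict[Tuple[int, int], float]:
    """Eq. (10): C_ij = gamma_i * alpha_ij for every ordered event pair (i, j)."""
    n = len(gamma)
    return {(i, j): gamma[i] * alpha[i][j] for i in range(n) for j in range(n)}


def top_patterns(contrib: Dict[Tuple[int, int], float], top: int = 10
                 ) -> List[Tuple[Tuple[int, int], float]]:
    """Return the `top` patterns sorted by contribution, highest first."""
    return sorted(contrib.items(), key=lambda kv: kv[1], reverse=True)[:top]
```

Averaging such per-patient contributions over case patients, after mapping each event to a discretized (variable, range) label, is one plausible way to obtain average contribution rates like those reported for mortality; that aggregation is our reading of the text, not code from the paper.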

Results and discussion

To evaluate the effectiveness of the proposed PAVE, we compare it with several state-of-the-art methods on two real-world clinical datasets: the publicly available MIMIC-III [14] and a proprietary EHR database. The experiments are conducted on two tasks: sepsis onset prediction and mortality prediction.

Datasets

The datasets for both the sepsis prediction and mortality prediction tasks come from Intensive Care Units (ICUs).

Sepsis prediction

The first dataset is extracted from a real-world proprietary EHR database. We use patients’ demographic information and 27 kinds of time series features, including vital signs and lab tests, to predict sepsis onset several hours ahead. Sepsis is one of the leading causes of mortality in hospitalized patients. We follow the sepsis 2 definition [24], under which sepsis 2 patients must meet at least two of the following four SIRS criteria:

  • Body temperature > 38.0 °C or < 35.0 °C

  • Respiratory rate > 20/min or PaCO2 < 32 mmHg

  • Heart rate > 90/min

  • WBC > 12,000/mm³ or < 4,000/mm³, or bands > 10%.
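The SIRS part of this definition can be sketched as a simple count. The thresholds are copied from the list above; the units (°C, breaths/min, beats/min, cells/mm³, % bands), the function names, and the convention that an unmeasured variable does not satisfy its criterion are our assumptions. The sketch checks only the SIRS criteria listed here and nothing else of the sepsis 2 definition.

```python
from typing import Optional


def sirs_count(temp_c: Optional[float] = None,
               resp_rate: Optional[float] = None,
               paco2_mmhg: Optional[float] = None,
               heart_rate: Optional[float] = None,
               wbc_per_mm3: Optional[float] = None,
               bands_pct: Optional[float] = None) -> int:
    """Number of the four SIRS criteria met; missing measurements count as not met."""
    met = 0
    if temp_c is not None and (temp_c > 38.0 or temp_c < 35.0):
        met += 1
    if (resp_rate is not None and resp_rate > 20) or \
       (paco2_mmhg is not None and paco2_mmhg < 32):
        met += 1
    if heart_rate is not None and heart_rate > 90:
        met += 1
    if (wbc_per_mm3 is not None and (wbc_per_mm3 > 12000 or wbc_per_mm3 < 4000)) or \
       (bands_pct is not None and bands_pct > 10):
        met += 1
    return met


def meets_sirs_threshold(**measurements: float) -> bool:
    """True when at least two of the four criteria are met."""
    return sirs_count(**measurements) >= 2
```

Each criterion contributes at most one point, so for example a high respiratory rate together with a low PaCO2 still counts once.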

Mortality prediction

The second dataset is the publicly available MIMIC-III dataset [14]. We use patients’ demographics and 8 vital signs to predict mortality in the coming hours. For each case patient (with sepsis 2 onset or mortality) in both datasets, 3 patients with the same age and gender are chosen as controls. For both cases and controls, our model predicts whether the patient suffers sepsis onset or mortality after a hold-off prediction window (e.g., 10, 8, 6, or 4 h). PAVE and the baselines take the variables observed during the last 48 h as inputs (data in the hold-off window are excluded). The statistics of the selected datasets are listed in Table 1, and the selected variables are listed in Table 2.
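The cohort construction described above can be sketched in two steps. Assumptions we introduce: times are in hours, each patient has a reference time (the outcome time for cases; for controls the text does not say, so something like the last recorded time is a hypothetical stand-in), controls are drawn without replacement with exactly equal age and gender, and all names are hypothetical.

```python
import random
from typing import Dict, List, Sequence, Tuple

Event = Tuple[float, str]  # (time in hours, event label)


def select_observation_window(events: Sequence[Event], reference_time_h: float,
                              hold_off_h: float, window_h: float = 48.0) -> List[Event]:
    """Keep events in [ref - hold_off - window, ref - hold_off); the hold-off window is dropped."""
    end = reference_time_h - hold_off_h
    start = end - window_h
    return [(t, e) for (t, e) in events if start <= t < end]


def match_controls(cases: Sequence[Tuple[str, int, str]],
                   pool: Sequence[Tuple[str, int, str]],
                   n_controls: int = 3, seed: int = 0) -> Dict[str, List[str]]:
    """For each (id, age, gender) case, draw up to n_controls unused controls with equal age and gender."""
    rng = random.Random(seed)
    available = list(pool)
    matched: Dict[str, List[str]] = {}
    for case_id, age, gender in cases:
        candidates = [p for p in available if p[1] == age and p[2] == gender]
        chosen = rng.sample(candidates, min(n_controls, len(candidates)))
        for p in chosen:
            available.remove(p)  # sampling without replacement across cases
        matched[case_id] = [p[0] for p in chosen]
    return matched
```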

Table 1 Statistics of datasets
Table 2 Selected variables used for sepsis onset and mortality prediction
Table 3 AUROC mean ± std on sepsis and mortality prediction

Methods for comparison

To validate the performance of PAVE, we compare it with the following models: three traditional machine learning methods and four deep learning methods. To demonstrate the effectiveness of the proposed time embedding and value embedding, we also implement three versions of PAVE.

Random forest (RF): We represent each patient’s demographics as a vector. For each variable, we extract the minimum and maximum observed values. The concatenation of these vectors is used to train the random forest model.

Logistic regression (LR): We train the logistic regression model with the same vectors as the random forest. The logistic regression is trained with five different solvers, namely lbfgs, newton-cg, liblinear, sag and saga, and we choose the solver with the best performance on the validation set.

Support vector machine (SVM): We train the support vector machine model with the same vectors as the random forest, using four different kernels: poly, rbf, linear and sigmoid. The kernel with the best performance on the validation set is used to predict the risk on the test set.

GRU and LSTM: GRU [25] and LSTM [26] are classical RNN-based models, both of which introduce gating mechanisms to improve the RNN’s performance.

RETAIN: The REverse Time AttentIoN model (RETAIN) [8] is the first work that tries to interpret a model’s disease risk predictions with two attention modules. The attention modules generate a weight for every medical event, which helps analyze the contributions of different events to the output risk.

IFM: IFM [21] is an interpretable heart failure risk prediction model, which is also based on attention mechanism and leverages the attention weights to interpret the outputs. In this work, we modify the IFM to predict sepsis onset and mortality.

\(\mathbf{PAVE }^{-T}\): \({\text {PAVE}}^{-T}\) removes the time embedding module when predicting patient risks.

\(\mathbf{PAVE }^{-V}\): \({\text {PAVE}}^{-V}\) removes the variable value embedding. It fills missing values with mean values and takes the filled values, rather than value embeddings, as inputs.

PAVE: PAVE is the main version of the proposed model.
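The fixed-length input shared by RF, LR and SVM can be sketched as below. The function name is hypothetical, and the fill value for a variable that was never observed in the 48-h window is our assumption, because the text does not say how such variables are handled.

```python
from typing import Dict, List, Sequence


def baseline_features(demo_vec: Sequence[float],
                      series: Dict[str, Sequence[float]],
                      variables: Sequence[str],
                      missing: float = 0.0) -> List[float]:
    """Demographic vector followed by (min, max) of each variable, in a fixed order.

    `missing` is a hypothetical fill value for variables with no observation.
    """
    feats = list(demo_vec)
    for var in variables:
        vals = series.get(var, [])
        if vals:
            feats.extend([min(vals), max(vals)])
        else:
            feats.extend([missing, missing])
    return feats
```

Rows built this way could then be passed to scikit-learn estimators such as `RandomForestClassifier`, `LogisticRegression(solver=...)` or `SVC(kernel=...)` for the solver and kernel selection described above. Because only the extremes are kept, the temporal order of the observations is lost.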

Implementation details

We implement all the baselines and the proposed PAVE models with PyTorch 0.4.1 (see Note 1) and scikit-learn (see Note 2). For the traditional machine learning approaches (i.e., LR, RF and SVM), a grid search is used to find the best parameter settings. For the deep learning approaches (i.e., GRU, LSTM, RETAIN, IFM and PAVE), we use the Adam optimizer with a mini-batch of 64 patients and a learning rate of 0.0001, and train for 50 epochs on one GPU (TITAN XP). We randomly divide the datasets into 10 sets. All the experimental results are averaged over tenfold cross-validation, in which each time 7 sets are used for training, 1 set for validation and 2 sets for testing. The validation sets are used to determine the best parameter values during the training iterations. We use the area under the receiver operating characteristic curve (AUROC) on the test sets to compare the performance of all the methods on the two datasets. The dimensions of the embedding and hidden vectors in the deep learning baselines and PAVE are set to 512. We use only 1 self-attention layer in PAVE, which captures two-event patterns. The numbers of trainable parameters of GRU, LSTM, RETAIN, IFM and PAVE are about 3.6 M, 4.4 M, 8.4 M, 1.2 M and 1.9 M, respectively.
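The 7/1/2 rotation over 10 random sets can be sketched as follows. How the paper assigns sets to roles in each round is not stated; rotating a window of three consecutive sets (two for testing, one for validation) is our assumption, and `tenfold_splits` is a hypothetical name.

```python
import random
from typing import Iterable, Iterator, List, Tuple


def tenfold_splits(ids: Iterable[int], seed: int = 0
                   ) -> Iterator[Tuple[List[int], List[int], List[int]]]:
    """Yield (train, val, test) for 10 rounds: 7 sets train, 1 set validation, 2 sets test."""
    ids = list(ids)
    random.Random(seed).shuffle(ids)
    folds = [ids[f::10] for f in range(10)]  # 10 roughly equal random sets
    for r in range(10):
        test_sets = (r, (r + 1) % 10)
        val_set = (r + 2) % 10
        test = folds[test_sets[0]] + folds[test_sets[1]]
        val = folds[val_set]
        train = [x for f in range(10) if f not in (*test_sets, val_set) for x in folds[f]]
        yield train, val, test
```

Under this rotation every set serves as test data in exactly two rounds, so each patient contributes to two of the ten AUROC estimates that are averaged.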

Results of risk prediction

As shown in Table 3, the proposed PAVE outperforms all the baselines, which demonstrates the effectiveness of our model.

The deep learning approaches outperform the traditional machine learning approaches, which take vectors rather than sequence data as inputs. The inputs of the traditional approaches lose the temporal information of the EHR data, which is very important in risk prediction tasks, whereas deep learning models are good at modeling temporal data. Thus, the deep learning baselines achieve better performance. Among the deep learning baselines, the attention-based models (i.e., RETAIN and IFM) perform better than the other models in the mortality prediction task, while LSTM and GRU perform better in the sepsis onset prediction task. We speculate that mortality is easier to predict from a few vital sign features, such as heart rate and respiratory rate in recent hours; attention-based models do well at capturing important events and thus achieve better performance. Sepsis is a complex disease that is more difficult to predict than mortality, and its onset is related to changes in patients’ health states over a relatively long period. LSTM and GRU are better at modeling long-term changes in these states, while RETAIN and IFM lose some temporal information through their attention mechanisms. In the clinical domain, models’ interpretability can be more important than their performance, so the interpretable risk prediction models (i.e., PAVE, RETAIN and IFM) are more suitable for real-world clinical applications. Compared with RETAIN and IFM, PAVE leverages attention mechanisms to focus on important events and incorporates time information through the time embedding, so it outperforms RETAIN and IFM by 1.5 percent and 3 percent on the sepsis and mortality prediction tasks, respectively.

Among the three versions of the proposed model, \({\text {PAVE}}^{-T}\) performs worse than PAVE, which indicates that the time embedding helps PAVE capture the information in the time intervals. PAVE also outperforms \({\text {PAVE}}^{-V}\), which takes imputed values rather than value embeddings as inputs. The imputation strategy may introduce bias and thus harm the final risk prediction.

Table 4 Top 10 patterns with the highest average contribution rates (AVG-CR) to mortality

Medical event pattern analysis

PAVE can analyze the patterns’ contributions to the prediction. We compute each pattern’s contribution to the risk of mortality for each patient according to Eq. (10). The values of each variable are divided into five ranges. By comparing each item value to its normal range, the value is first mapped into one of three ranges (i.e., low, normal and high). The high-value range is then divided into two parts (i.e., high and very high) by comparing the value to the median of all the high values, and the low-value range is divided in the same way (i.e., low and very low). Table 4 lists the top 10 patterns with the highest average contribution rates to mortality (10-h mortality prediction) among all the case patients. Clinicians verified these patterns as high-risk signals of mortality, which demonstrates that PAVE can find useful patterns in the prediction tasks.
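The five-range discretization can be sketched as below. We assume the medians are computed over all observed values of that variable falling outside the normal range, that a value equal to a median stays in the milder category, and that the normal range is supplied by the caller; the function names are hypothetical.

```python
import statistics
from typing import Optional, Sequence, Tuple


def range_medians(values: Sequence[float], normal_low: float, normal_high: float
                  ) -> Tuple[Optional[float], Optional[float]]:
    """Medians of all values below / above the normal range (None if there are none)."""
    low = [v for v in values if v < normal_low]
    high = [v for v in values if v > normal_high]
    return (statistics.median(low) if low else None,
            statistics.median(high) if high else None)


def discretize(value: float, normal_low: float, normal_high: float,
               low_median: Optional[float], high_median: Optional[float]) -> str:
    """Map a value to very low / low / normal / high / very high."""
    if value > normal_high:
        return "very high" if high_median is not None and value > high_median else "high"
    if value < normal_low:
        return "very low" if low_median is not None and value < low_median else "low"
    return "normal"
```

For example, with a hypothetical normal temperature range of 36.0-37.5 °C and observed values 35.0, 35.8, 36.5, 37.0, 38.0, 39.0 and 40.0, the high values have median 39.0, so only 40.0 is labeled "very high".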

We repeated the experiments many times and found that some patterns consistently receive relatively high weights. For example, the weight of the pattern (very high temperature, very low respiratory rate) is always much higher than those of random patterns, which is consistent with the clinical knowledge that patients with both very high temperature and very low respiratory rate have a high risk of mortality.

Fig. 3

Case study. The figure shows the observed variables of a case patient during the last 24 h observation window (the hold-off window is 10 h). The black stars represent observed abnormal values with instance-wise contribution risks generated by RETAIN, while the colored squares are medical event patterns detected by PAVE. The events sharing the same colors are detected patterns. The sizes of black stars and colored squares denote the corresponding values of contribution risks. Note that only the events or patterns with relatively higher contribution risks are marked in the figure. PAVE found three high-risk patterns for the patient: (1) high SysBP and high temperature in orange squares; (2) high heart rate and high temperature in red squares; (3) stable high heart rate and high respiratory rate in blue squares

Case study

We applied PAVE to predict the mortality risk of a patient from the test set whose death occurred after the 10-h hold-off window. Fig. 3 displays the variables observed during the last 24 h of the observation window. For comparison, RETAIN is also used to predict the mortality risk. Both PAVE and RETAIN accurately predict the patient’s mortality after 10 h. In this case study, we mainly focus on the interpretability of the detected medical events or patterns with high contribution risks. The black stars in Fig. 3 represent observed abnormal values with high instance-wise contribution risks generated by RETAIN, while the colored squares are medical event patterns detected by PAVE. In this case, PAVE found three patterns with high contribution risks: (1) high SysBP and high temperature in orange squares; (2) high heart rate and high temperature in red squares; (3) stable high heart rate and high respiratory rate in blue squares. Events sharing the same color belong to the same detected pattern. Note that only the patterns with relatively high contribution risks are shown in the figure. The sizes of the black stars and colored squares denote the corresponding contribution risks. Both models successfully detect some crucial medical events related to high mortality risk, such as high heart rate and high temperature. PAVE focuses much more on the variables observed during the last 10 h of the observation window (e.g., the stable high heart rate and high respiratory rate in blue squares), while RETAIN attends to many earlier events but ignores the later high heart rate and high respiratory rate in the last three collections. This indicates that PAVE has learned that both recent medical events and abnormal values are useful for accurate mortality prediction, while RETAIN focuses only on abnormal values.
Moreover, when crucial patterns appear (e.g., the high heart rate and high respiratory rate in blue squares), PAVE assigns more attention weight to them than RETAIN assigns to the corresponding events (the colored squares are larger than the corresponding stars). This indicates that PAVE is effective at mining relevant and important patterns and pays more attention to meaningful ones.
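The marking rule used in Fig. 3 (show only the patterns with relatively high contribution risks) can be sketched as a ranking over pairs of observed events. This is a minimal illustration under a stated assumption: a pair's score is taken to be the mean of the two directed self-attention weights, (a_ij + a_ji) / 2, which is not necessarily PAVE's exact contribution-risk formula (that is defined in the Methods section). The function name `top_patterns`, the event labels, and the weight values are hypothetical.

```python
from itertools import combinations


def top_patterns(events, attention, k=3):
    """Return the k event pairs with the highest symmetric attention score.

    events:    labels of the observed events, one per input token.
    attention: n x n self-attention weights; attention[i][j] is how strongly
               event i attends to event j.
    Assumption (not PAVE's exact formula): a pair's score is the mean of
    the two directed weights, (a_ij + a_ji) / 2.
    """
    n = len(events)
    if len(attention) != n or any(len(row) != n for row in attention):
        raise ValueError("attention must be an n x n matrix")
    scored = [
        ((attention[i][j] + attention[j][i]) / 2, events[i], events[j])
        for i, j in combinations(range(n), 2)
    ]
    # Highest score first; sorted() is stable, so ties keep pair order.
    return sorted(scored, key=lambda t: t[0], reverse=True)[:k]


# Hypothetical weights for four observed abnormal events of one patient.
events = ["HR high (t-3h)", "RR high (t-3h)", "Temp high (t-12h)", "SysBP high (t-12h)"]
attention = [
    [0.10, 0.60, 0.20, 0.10],
    [0.55, 0.15, 0.20, 0.10],
    [0.10, 0.10, 0.30, 0.50],
    [0.10, 0.10, 0.45, 0.35],
]
for score, a, b in top_patterns(events, attention, k=2):
    print(f"{a} + {b}: {score:.3f}")
```

Scoring pairs rather than single events is what distinguishes pattern-wise from instance-wise interpretation: a pair can rank highly even when neither event's own weight stands out.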

Conclusion

In this work, we proposed PAVE, an interpretable pattern attention model with value embedding for disease risk prediction. PAVE takes real-valued medical events (e.g., lab tests and vital signs) into account by embedding the observed values into vectors, and therefore does not need to impute missing values. Moreover, PAVE is based on attention mechanisms, and its attention weights can be used to interpret the model's clinical outputs. To the best of our knowledge, PAVE is the first interpretable deep learning model that provides medical pattern-wise interpretability in addition to instance-wise interpretability. This matters because an event pattern may carry a much higher risk than any single event in it. We conducted extensive experiments on two real-world datasets, and PAVE outperformed state-of-the-art models. Moreover, the experimental results show that PAVE is able to detect many medical event patterns with high contribution rates to mortality and sepsis onset, which paves the way for interpretable clinical risk prediction.
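Why no imputation is needed follows from how the input is built: only observed (variable, value, time) triples become tokens, so a missing measurement contributes no token at all. The toy sketch below illustrates that idea together with single-head scaled dot-product self-attention [22] over the resulting tokens. It is not the PAVE implementation. It assumes a value embedding obtained by linearly interpolating between a "low" and a "high" vector according to where the value falls in a reference range, plus a sinusoidal time encoding; the reference ranges, dimensions, random weights, and all function names are hypothetical.

```python
import math
import random

DIM = 4                 # toy embedding size (hypothetical)
rng = random.Random(0)  # fixed seed so the sketch is reproducible


def rand_vec():
    return [rng.uniform(-1.0, 1.0) for _ in range(DIM)]


def rand_mat():
    return [rand_vec() for _ in range(DIM)]


# Hypothetical reference ranges used to scale each variable to [0, 1].
RANGES = {"HR": (40.0, 180.0), "RR": (5.0, 40.0), "Temp": (34.0, 42.0)}
VAR_EMB = {v: rand_vec() for v in RANGES}   # which variable was measured
VAL_LOW = {v: rand_vec() for v in RANGES}   # embedding at the bottom of the range
VAL_HIGH = {v: rand_vec() for v in RANGES}  # embedding at the top of the range
W_Q, W_K = rand_mat(), rand_mat()           # query/key projections (random here)


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def time_encoding(hours):
    """Transformer-style sinusoidal encoding of the event time."""
    return [math.sin(hours / 10000 ** (i / DIM)) if i % 2 == 0
            else math.cos(hours / 10000 ** ((i - 1) / DIM))
            for i in range(DIM)]


def embed(var, value, hours):
    """Variable + value + time embedding for ONE observed event."""
    lo, hi = RANGES[var]
    w = min(max((value - lo) / (hi - lo), 0.0), 1.0)  # clip to [0, 1]
    val = [(1 - w) * a + w * b for a, b in zip(VAL_LOW[var], VAL_HIGH[var])]
    return [e + v + t for e, v, t in zip(VAR_EMB[var], val, time_encoding(hours))]


def self_attention(tokens):
    """Single-head scaled dot-product attention weights (softmax over keys)."""
    qs = [matvec(W_Q, x) for x in tokens]
    ks = [matvec(W_K, x) for x in tokens]
    weights = []
    for q in qs:
        logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(DIM) for k in ks]
        top = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(z - top) for z in logits]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Only observed events become tokens: there is no Temp token at 2 h,
# so nothing has to be imputed for that missing measurement.
observed = [("HR", 125.0, 1.0), ("RR", 28.0, 1.0), ("HR", 130.0, 2.0), ("Temp", 39.2, 0.5)]
tokens = [embed(*event) for event in observed]
attn = self_attention(tokens)
for (var, value, hours), row in zip(observed, attn):
    print(f"{var}={value} at {hours} h:", [round(w, 2) for w in row])
```

Each row of `attn` sums to 1, and entry (i, j) plays the role of the attention weight between two observed events, which is the kind of quantity a pattern-level interpretation would read off.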

Availability of data and materials

The MIMIC-III database analyzed in this study is available from the PhysioNet repository. The source code for reproducing the results is available at https://github.com/yinchangchang/PAVE.

Notes

  1. https://pytorch.org/
  2. https://scikit-learn.org/stable/

Abbreviations

EHR: Electronic health record
PAVE: Pattern attention model with value embedding
FC: Fully connected layer
MIMIC-III: Medical information mart for intensive care III
LR: Logistic regression
RF: Random forest
SVM: Support vector machine
RNN: Recurrent neural network
LSTM: Long short-term memory
GRU: Gated recurrent unit
RETAIN: Reverse time attention model
IFM: An interpretable fast model
AUROC: Area under the receiver operating characteristic curve
AVG: Average
CR: Contribution rate
No.: Number
CO2: Carbon dioxide
FiO2: Fraction of inspired oxygen
MAP: Mean arterial pressure
RBC: Red blood cell count
SPO2: Oxygen saturation
WBC: White blood cell count
SysBP: Systolic blood pressure
DiasBP: Diastolic blood pressure
MeanBP: Mean blood pressure

References

  1. Cheng Y, Wang F, Zhang P, Hu J. Risk prediction with electronic health records: a deep learning approach. In: Proceedings of the 2016 SIAM international conference on data mining. London: SIAM; 2016. p. 432–40.
  2. Zhang X, Qian B, Li Y, et al. KnowRisk: an interpretable knowledge-guided model for disease risk prediction. In: 2019 IEEE international conference on data mining (ICDM). New York: IEEE; 2019. p. 1492–7.
  3. Baytas IM, Xiao C, Zhang X, et al. Patient subtyping via time-aware LSTM networks. In: Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining; 2017. p. 65–74.
  4. Yin C, Liu R, Zhang D, Zhang P. Identifying sepsis subphenotypes via time-aware multi-modal auto-encoder. In: Proceedings of the 26th ACM SIGKDD international conference on knowledge discovery & data mining; 2020. p. 862–72.
  5. Glass TA, Goodman SN, Hernán MA, Samet JM. Causal inference in public health. Annu Rev Public Health. 2013;34:61–75.
  6. Liu R, Yin C, Zhang P. Estimating individual treatment effects with time-varying confounders. arXiv preprint; 2020. arXiv:2008.13620.
  7. Zhu Z, Yin C, Qian B, et al. Measuring patient similarities via a deep architecture with medical concept embedding. In: 2016 IEEE 16th international conference on data mining (ICDM). London: IEEE; 2016. p. 749–58.
  8. Choi E, Bahadori MT, Sun J, et al. RETAIN: an interpretable predictive model for healthcare using reverse time attention mechanism. In: Advances in neural information processing systems; 2016. p. 3504–12.
  9. Choi E, Bahadori MT, Song L, et al. GRAM: graph-based attention model for healthcare representation learning. In: Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining; 2017. p. 787–95.
  10. Ma F, You Q, Xiao H, et al. KAME: knowledge-based attention model for diagnosis prediction in healthcare; 2018. p. 743–52.
  11. Ma F, Chitta R, Zhou J, et al. Dipole: diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In: Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining; 2017. p. 1903–11.
  12. Yin C, Qian B, Cao S, Li X, Wei J, Zheng Q, et al. Deep similarity-based batch mode active learning with exploration–exploitation. In: 2017 IEEE international conference on data mining (ICDM). IEEE; 2017. p. 575–84.
  13. Che Z, Purushotham S, Cho K, et al. Recurrent neural networks for multivariate time series with missing values. Sci Rep. 2018;8(1):1–12.
  14. Johnson AE, Pollard TJ, Shen L, et al. MIMIC-III, a freely accessible critical care database. Sci Data. 2016;3(1):1–9.
  15. Razavian N, Marcus J, Sontag D. Multi-task prediction of disease onsets from longitudinal laboratory tests. In: Machine learning for healthcare conference; 2016. p. 73–100.
  16. Ma F, Gao J, Suo Q, You Q, Zhou J, Zhang A. Risk prediction on electronic health records with prior medical knowledge. In: Proceedings of the 24th ACM SIGKDD international conference on knowledge discovery and data mining; 2018. p. 1910–9.
  17. Choi E, Schuetz A, Stewart WF, Sun J. Using recurrent neural network models for early detection of heart failure onset. J Am Med Inform Assoc. 2017;24(2):361–70.
  18. Beeksma M, Verberne S, Van den Bosch A, et al. Predicting life expectancy with a long short-term memory recurrent neural network using electronic medical records. BMC Med Inform Decis Mak. 2019;19(1):36.
  19. Che Z, Kale D, Li W, Bahadori MT, Liu Y. Deep computational phenotyping. In: Proceedings of the 21st ACM SIGKDD international conference on knowledge discovery and data mining; 2015. p. 507–16.
  20. Lipton ZC, Kale D, Wetzel R. Directly modeling missing data in sequences with RNNs: improved classification of clinical time series. In: Machine learning for healthcare conference; 2016. p. 253–70.
  21. Zhang X, Qian B, Li X, Wei J, Zheng Y, Song L, et al. An interpretable fast model for predicting the risk of heart failure. London: SIAM; 2019. p. 576–84.
  22. Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, et al. Attention is all you need. In: Advances in neural information processing systems; 2017. p. 5998–6008.
  23. Gligorijevic D, Stojanovic J, Satz W, Stojkovic I, Schreyer K, Del Portal D, et al. Deep attention model for triage of emergency department patients. In: Proceedings of the 2018 SIAM international conference on data mining. London: SIAM; 2018. p. 297–305.
  24. Levy M. 2001 SCCM/ESICM/ATS/SIS international sepsis definitions conference. Crit Care Med. 2003;31:1250–6.
  25. Cho K, van Merrienboer B, Bahdanau D, Bengio Y. On the properties of neural machine translation: encoder–decoder approaches; 2014.
  26. Hochreiter S, Schmidhuber J. Long short-term memory. Neural Comput. 1997;9(8):1735–80.


Acknowledgements

Not applicable.

About this supplement

This article has been published as part of BMC Medical Informatics and Decision Making Volume 20 Supplement 11 2020: Informatics and machine learning methods for health applications. The full contents of the supplement are available at https://bmcmedinformdecismak.biomedcentral.com/articles/supplements/volume-20-supplement-11.

Funding

This work is supported in part by “China Northwest Cohort Study” under the National Key Research and Development Program of China with grant number 2018YFC130078 (for SK and BQ) and “Multi-model Based Patient Similarity Learning for Medical Data Modelling and Learning” under National Natural Science Foundation of China General Program with grant number 61672420 (for SK and BQ). This project is supported in part by The Ohio State University (for CY and PZ). Publication costs are funded by The Ohio State University.

Author information

Contributions

BQ and PZ conceived the project. SK and CY developed the method and conducted the experiments. SK, CY, BQ, and PZ wrote the manuscript. All authors read and approved the final manuscript.

Corresponding author

Correspondence to Ping Zhang.

Ethics declarations

Ethics approval and consent to participate

The proprietary EHR data used in this study was extracted from Xi’an Jiaotong University and the study has been approved by Xi’an Jiaotong University. Requirement for individual patient consent was waived since all protected health information was de-identified. De-identification was performed in compliance with Health Insurance Portability and Accountability Act (HIPAA) standards. Deletion of protected health information (PHI) from structured data sources (e.g., database fields that provide patient name or date of birth) was straightforward. Research use of publicly available MIMIC-III has been approved by the Institutional Review Board of Beth Israel Deaconess Medical Center and Massachusetts Institute of Technology [14].

Consent for publication

Not applicable.

Competing interests

PZ is a member of the editorial board of BMC Medical Informatics and Decision Making. The authors declare that they have no other competing interests.

Additional information

Publisher's note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Rights and permissions

Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/. The Creative Commons Public Domain Dedication waiver (http://creativecommons.org/publicdomain/zero/1.0/) applies to the data made available in this article, unless otherwise stated in a credit line to the data.


About this article


Cite this article

Kamal, S.A., Yin, C., Qian, B. et al. An interpretable risk prediction model for healthcare with pattern attention. BMC Med Inform Decis Mak 20, 307 (2020). https://doi.org/10.1186/s12911-020-01331-7


Keywords

  • EHR
  • Risk prediction
  • Self-attention
  • Interpretability