Interpretable time-aware and co-occurrence-aware network for medical prediction

Background Disease prediction based on electronic health records (EHRs) is essential for personalized healthcare, but it is difficult because of the special structure of EHR data and the requirement that methods be interpretable. EHR data are hierarchical: each patient has a sequence of admissions, and each admission contains a set of co-occurring diagnoses. However, existing methods model these characteristics only partially and offer little interpretation for non-specialists. Methods This work proposes a time-aware and co-occurrence-aware deep learning network (TCoN) that both fits the EHR data structure and is interpretable: the co-occurrence-aware self-attention (CS-attention) mechanism and the time-aware gated recurrent unit (T-GRU) model multilevel relations, and the interpretation path and the diagnosis graph make the results interpretable. Results The method is tested on a real-world dataset for mortality prediction, readmission prediction, disease prediction, and next diagnoses prediction. Experimental results show that TCoN outperforms the baselines, with 2.01% higher accuracy. Meanwhile, the method can interpret causal relationships and give the diagnosis graph of each patient. Conclusions This work proposes a novel model, TCoN, an interpretable and effective deep learning method that can model the hierarchical medical structure and predict medical events. The experiments show that it outperforms all state-of-the-art methods. Future work can apply graph embedding technology based on more knowledge data, such as doctor notes.


Background
Electronic Health Records (EHRs) are increasingly popular and widely used in hospitals for better healthcare management. A typical EHR dataset contains a large amount of patient information, including demographic and medical information. The medical information has an irregular hierarchical patient-visit-code (patient-admission-diagnosis) form, shown in Fig. 1a: (1) Each patient has many visit records, since he/she may see a doctor many times. The visit records have corresponding time stamps and form a sequence. (2) Each visit contains many codes, which are usually disease diagnoses. The codes within a visit are related by co-occurrence, without any order. For example, if chronic kidney disease is recorded after a cold in a patient's record, we cannot conclude that the patient did not have chronic kidney disease before catching the cold: the two diagnoses have an uncertain time relation. We call such relations, which include complication, causation, and continuity, the co-occurrence relation. Thus, EHR data have both a time relation and a co-occurrence relation.

Open Access. *Correspondence: sun_chenxi@pku.edu.cn. 1 School of Electronics Engineering and Computer Science, Peking University, No. 5 Yiheyuan Road, Beijing 100871, People's Republic of China. Full list of author information is available at the end of the article.
Medical tasks such as disease prediction [1][2][3], concept representation [4,5], and patient typing [6][7][8] are essential for personalized healthcare and medical research. Nevertheless, these tasks are challenging for physicians, given complex patient states, the large number of diagnoses, and real-time requirements. Thus, a data-driven approach that learns from large, accessible EHRs is desirable.
In recent years, Deep Learning (DL) models have achieved remarkable results thanks to their strong learning ability and flexible architectures [9][10][11][12][13]. Some DL methods model the sequential time relation of medical data. For example, RETAIN [3] utilizes the gated recurrent unit (GRU) [14,15] to predict medical events, Dipole [1] uses a Bidirectional RNN (BRNN) [16] to integrate information from the past and the future, and T-LSTM [8,17] injects a time decay effect to handle irregular time intervals. These methods model the EHR structure as in Fig. 1b. Other DL methods model the co-occurrence relation of medical data. For example, Word2Vec [18,19], Med2Vec [4], and MiME [5] use the idea of representation learning [20][21][22][23] to model medical relations and better express the original data. These methods model the EHR structure as in Fig. 1c.
However, no existing method models both relations simultaneously, because the two relations conflict: the time relation arranges data longitudinally, whereas the co-occurrence relation arranges data like a bipartite graph. When both relations are considered, the EHR structure is as shown in Fig. 1d.
To address these issues, we define EHR data as a hierarchical co-occurrence sequence and propose a novel model called the Time-aware and Co-occurrence-aware Network (TCoN). TCoN not only models the two relations simultaneously but also provides interpretations. TCoN uses a pre-train and fine-tune mechanism to handle imbalanced data and is more accurate than all baselines in medical prediction tasks.

Materials and methods
In this section, we first introduce the MIMIC-III dataset and the data preprocessing. Then, we describe the proposed method in detail.

Dataset description and preprocessing
MIMIC-III is a freely accessible, de-identified medical dataset developed and maintained by the Massachusetts Institute of Technology Laboratory for Computational Physiology [33]. From MIMIC-III, we extract three datasets:

Overall dataset
We extract the records of patients with more than one visit from MIMIC-III. The resulting dataset comprises 19,993 hospital admissions of 7537 patients and 260,326 diagnoses, with 4,893 unique codes defined by the International Classification of Diseases, Ninth Revision (ICD-9). On average, a patient has 2.66 visits. A visit has 13.02 codes on average and up to 39.
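As an illustration of the multi-visit filter, the sketch below groups diagnosis rows by patient and admission and keeps only patients with more than one admission. It assumes rows shaped like MIMIC-III `DIAGNOSES_ICD` records (`SUBJECT_ID`, `HADM_ID`, `ICD9_CODE`) already loaded as tuples; the function names are hypothetical, and ordering admissions by time (from the `ADMISSIONS` table) is omitted for brevity.

```python
from collections import defaultdict


def build_multi_visit_cohort(rows):
    """Group (subject_id, hadm_id, icd9_code) rows into patient -> admission -> codes,
    keeping only patients with more than one admission."""
    patients = defaultdict(lambda: defaultdict(list))
    for subject_id, hadm_id, icd9_code in rows:
        patients[subject_id][hadm_id].append(icd9_code)
    return {pid: dict(adms) for pid, adms in patients.items() if len(adms) > 1}


def cohort_statistics(cohort):
    """Summary counts and averages of the kind reported for the overall dataset."""
    n_patients = len(cohort)
    n_visits = sum(len(adms) for adms in cohort.values())
    n_codes = sum(len(codes) for adms in cohort.values() for codes in adms.values())
    unique = {c for adms in cohort.values() for codes in adms.values() for c in codes}
    return {
        "patients": n_patients,
        "admissions": n_visits,
        "diagnoses": n_codes,
        "unique_codes": len(unique),
        "visits_per_patient": n_visits / n_patients,
        "codes_per_visit": n_codes / n_visits,
    }
```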

Sepsis dataset
Following the latest sepsis definition (Sepsis-3) [34], we extract 1232 sepsis patients whose SOFA score is greater than or equal to 2.

Heart failure dataset
Based on ICD-9 codes, we extract 1608 heart failure patients who have a diagnosis with code 428.x.
For the sepsis and heart failure datasets, the extracted records are those up to the first appearance of the respective diagnosis. Both datasets are imbalanced. Detailed statistics are shown in Table 1.

Fig. 1 The data structure of EHR based on different methods. a Original EHR data structure. b EHR data structure based on the time relation. c EHR data structure based on the co-occurrence relation. d Data relations under our TCoN model. Data form b arranges codes in a random order, but different orders affect the results differently. For example, the sequence 'heart disease -> influenza -> coronary' implies a closer relation between 'heart disease' and 'influenza' than the sequence 'heart disease -> coronary -> influenza'. Data form c gives every two codes an equal relation, but if 'heart disease', 'atrial fibrillation' and 'diabetes' are in three different visits, the equal relation fails because of the different time intervals among them. Data form d combines the two. Meanwhile, the demographic information I is recorded for each patient P.
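The heart failure selection and the first-occurrence extraction described above can be sketched as follows. The sketch assumes ICD-9 codes stored without the decimal point (as in MIMIC-III, e.g. '4280' for 428.0), visits already sorted by admission time, and that the admissions before the first heart-failure admission serve as the model input; `is_heart_failure` and `truncate_at_first_occurrence` are hypothetical helpers, not part of the released code.

```python
def is_heart_failure(code):
    """ICD-9 428.x (heart failure); codes are assumed to be stored without a dot, e.g. '4280'."""
    return code.startswith("428")


def truncate_at_first_occurrence(visits, is_target=is_heart_failure):
    """visits: chronologically ordered list of code lists, one list per admission.

    Returns (history, label). If some admission contains a target code, history holds the
    admissions before the first such admission (assumed model input) and label is 1;
    otherwise all admissions are returned with label 0.
    """
    for idx, codes in enumerate(visits):
        if any(is_target(code) for code in codes):
            return visits[:idx], 1
    return visits, 0
```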

Problem formulation
Definition 2 (Medical prediction tasks) A medical prediction task uses a set of medical records $R$ to predict a specific target $Y = \{y_1, y_2, \ldots, y_n\}$. If $n = 2$, it is a binary classification task; if $n > 2$, it is a multi-classification task. The prediction task is $f_p: R \rightarrow Y$.

Definition 3 (Interpretation Path)
Interpretation uses the correlations R of medical pairs

Analysis strategy
Task 1 (Mortality prediction). To predict whether the patient will die during the hospitalization.
Task 2 (Readmission prediction). To predict whether the patient will be hospitalized again.
Task 3 (Disease prediction). Two disease prediction tasks: sepsis and heart failure. Early diagnosis is critical for improving patient outcomes [35].
Task 4 (Next diagnoses prediction). To predict the diagnoses of the patient in the next admission.
Note that Tasks 1, 2, and 3 are binary classification tasks and Task 4 is a multi-classification task. The area under the curve of precision (P) and recall (R) is a better measure for imbalanced data [36].
Evaluation 3 (Accuracy@k). The proportion of correct (positive) predictions among the top-k prediction values. It is the evaluation metric for the multi-classification task.
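As a hedged sketch of these measures, the code below computes average precision, one common step-wise estimate of the area under the precision-recall curve, and Accuracy@k read as the fraction of the top-k scored codes that are true next-visit diagnoses. Both the estimator and this reading of Accuracy@k are assumptions, and the function names are hypothetical.

```python
def average_precision(labels, scores):
    """Step-wise estimate of the area under the precision-recall curve.

    labels: 0/1 ground truth; scores: predicted probabilities (higher = more positive).
    """
    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("average precision is undefined without positive samples")
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    true_pos, total = 0, 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            true_pos += 1
            total += true_pos / rank  # precision at the recall step added by this positive
    return total / n_pos


def accuracy_at_k(true_codes, scores, k):
    """Fraction of the k highest-scored code indices that are in the true code set."""
    top_k = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
    return sum(1 for c in top_k if c in true_codes) / k
```

Under this reading, Accuracy@k tends to fall once k exceeds the number of true diagnoses in a visit, which is consistent with the trend reported for Table 2b.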

TCoN model structure
As shown in Fig. 2, our TCoN model contains a code block and a visit block. The code block is implemented by Co-occurrence-aware Self-attention (CS-attention), the visit block is implemented by a Time-aware Gated Recurrent Unit (T-GRU), and the two blocks are joined by an attention connection.

CS-attention
Self-attention [32] in natural language processing captures the semantic and grammatical relations between different words in a sentence. Each input has three vectors: Query (Q), Key (K), and Value (V). The multi-head self-attention is designed as:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V, \qquad \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^{O},$$

where $\mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V})$.

In this work, we redesign self-attention as CS-attention (Eq. 5) to handle the relations among EHR codes. CS-attention has two kinds of heads: the local head and the global head. The local head learns the co-occurrence relations between every two codes in the same visit, where a code is affected by the other codes equally. The global head learns the co-occurrence relations between every two codes in different visits, where a code is affected by the other codes differently according to the time intervals between visits. Together, the two kinds of heads learn a new representation $\tilde{C}$ of each code $C$ from its neighbors $C_{nb}$. Here $C$ is the original matrix of input codes and $\tilde{C}$ is the new representation matrix. $Q_i$, $K_i$, $V_i$ are the same as in Eq. (4); $i = 1$ denotes the local head and $i = 2$ the global head. $d_k$ is the dimension of $Q$ and $K$. $T$ is the time decay function $g(\Delta t)$ in Eq. 4. The numbers of local heads and global heads can both be changed.
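A minimal sketch of one CS-attention head follows. It is not the paper's Eq. 5 but a reading of the description above, under several assumptions: Q, K, and V are the code vectors themselves (learned projections omitted), a code's neighbors exclude the code itself, the local head applies scaled dot-product attention over codes in the same visit with no time weighting, and the global head multiplies each attention weight over codes in other visits by $g(\Delta t)$ before renormalizing. The function `cs_attention_head` is hypothetical.

```python
import math


def _softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cs_attention_head(codes, visit_ids, visit_times, head, decay):
    """Sketch of one CS-attention head (hypothetical form).

    codes:       list of code vectors, used directly as Q, K and V
    visit_ids:   visit index of each code
    visit_times: admission time of each visit, indexed by visit id (e.g. in years)
    head:        "local" (same-visit neighbours) or "global" (other-visit neighbours)
    decay:       time decay function g(dt) > 0, applied only in the global head
    Returns (new_codes, weights); weights[i] maps neighbour j -> alpha_ij.
    """
    d_k = len(codes[0])
    new_codes, weights = [], []
    for i, q in enumerate(codes):
        if head == "local":
            nbrs = [j for j in range(len(codes)) if j != i and visit_ids[j] == visit_ids[i]]
        else:
            nbrs = [j for j in range(len(codes)) if visit_ids[j] != visit_ids[i]]
        if not nbrs:  # no neighbours: keep the original representation
            new_codes.append(list(q))
            weights.append({})
            continue
        logits = []
        for j in nbrs:
            score = _dot(q, codes[j]) / math.sqrt(d_k)
            if head == "global":
                dt = abs(visit_times[visit_ids[i]] - visit_times[visit_ids[j]])
                score += math.log(decay(dt))  # weight_j is proportional to exp(score_j) * g(dt)
            logits.append(score)
        alpha = _softmax(logits)
        new_codes.append([sum(a * codes[j][k] for a, j in zip(alpha, nbrs)) for k in range(d_k)])
        weights.append(dict(zip(nbrs, alpha)))
    return new_codes, weights
```

In a full model, the outputs of the local and global heads would be concatenated and projected, as in standard multi-head attention.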

T-GRU
As shown in Eq. 3, T-GRU comprises an update gate $z_t$ and a reset gate $r_t$. The update gate controls the extent to which the previous state $h_{t-1}$ is carried into the current state $h_t$, and the reset gate controls how much of the previous state is carried into the current candidate state $\tilde{h}_t$. To model time irregularity, we add a time gate $d_t$. This gate takes the time interval into account and, through the time decay function $g(\Delta t)$, controls how much information is delivered from the previous visit to the current visit; that is, the time decay function determines how much of the history state is injected into the current unit. In Eq. 3, $x_t$ is the current input, and $W$, $U$, $b$ are parameters. The output is the current state $h_t$.
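The sketch below shows one plausible T-GRU step together with the three decay forms of Eq. 7 (reciprocal, logarithmic, exponential). It assumes a scalar time gate $d_t = g(\Delta t)$ that discounts $h_{t-1}$ before the standard GRU update and reset gates are applied; the paper's Eq. 3 may combine the time gate differently, and all names (`t_gru_step`, the `params` keys) are hypothetical.

```python
import math


def g_reciprocal(dt, alpha=1.0):
    return 1.0 / (1.0 + alpha * dt)


def g_logarithmic(dt, alpha=1.0):
    return 1.0 / math.log(math.e + alpha * dt)


def g_exponential(dt, alpha=1.0):
    return math.exp(-alpha * dt)


def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def _affine(W, x, U, h, b):
    """W @ x + U @ h + b, with matrices given as lists of rows."""
    return [sum(w * xk for w, xk in zip(w_row, x)) + sum(u * hk for u, hk in zip(u_row, h)) + bi
            for w_row, u_row, bi in zip(W, U, b)]


def t_gru_step(x_t, h_prev, dt, params, decay=g_logarithmic):
    """One time-aware GRU step (sketch). params holds W_*, U_*, b_* for the z, r, h gates."""
    d_t = decay(dt)                           # assumed scalar time gate
    h_dec = [d_t * h for h in h_prev]         # discount the history by the elapsed time
    z = [_sigmoid(v) for v in _affine(params["W_z"], x_t, params["U_z"], h_dec, params["b_z"])]
    r = [_sigmoid(v) for v in _affine(params["W_r"], x_t, params["U_r"], h_dec, params["b_r"])]
    r_h = [ri * hi for ri, hi in zip(r, h_dec)]
    h_cand = [math.tanh(v) for v in _affine(params["W_h"], x_t, params["U_h"], r_h, params["b_h"])]
    return [(1 - zi) * hi + zi * ci for zi, hi, ci in zip(z, h_dec, h_cand)]
```

With $\alpha = 1$ and $\Delta t = 1$, the logarithmic form keeps more history (about 0.76) than the reciprocal (0.5) or exponential (about 0.37) forms, which matches the intuition that it suits large elapsed times.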
We propose three time decay functions for the time gate (Eq. 7), where $\Delta t$ is the time interval between two visits and $\alpha$ is the decay rate:

$$\text{Reciprocal form: } g(\Delta t) = \frac{1}{1 + \alpha \Delta t}, \qquad \text{Logarithmic form: } g(\Delta t) = \frac{1}{\log(e + \alpha \Delta t)}, \qquad \text{Exponential form: } g(\Delta t) = e^{-\alpha \Delta t}$$

When $\alpha = 1$, the exponential form is more suitable for small elapsed times, the logarithmic form is more suitable for large elapsed times, and the reciprocal form is a compromise.

Attention connection
Between the code block and the visit block, we design the connection method in Eq. 8, where $X_{v_i}$ is the input of the $i$th visit $v_i$, $C_i$ is the output matrix whose rows are the codes of the $i$th visit, and $W_\beta$ is a parameter vector. When we consider the demographic information $I$, the input is a concatenation: $X_{v_i} = \mathrm{concat}(\beta^{T} C_i, I_i)$.

Fig. 2 TCoN structure.
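A hypothetical sketch of such an attention connection, assuming that each code row of $C_i$ is scored against the parameter vector $W_\beta$, the scores are normalized with a softmax to give $\beta$, and the visit input is the weighted sum $\beta^{T} C_i$, optionally concatenated with the demographic vector $I_i$. The softmax scoring and the name `attention_connection` are assumptions.

```python
import math


def attention_connection(C_i, w_beta, demographics=None):
    """Pool the code representations of visit i into one visit input (hypothetical form).

    C_i:          rows are the code vectors of visit i (output of the code block)
    w_beta:       parameter vector W_beta used to score each code
    demographics: optional demographic vector I_i appended to the pooled vector
    Returns (X_vi, beta), where beta holds the per-code attention weights.
    """
    scores = [sum(c * w for c, w in zip(row, w_beta)) for row in C_i]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    beta = [e / total for e in exps]
    pooled = [sum(b * row[k] for b, row in zip(beta, C_i)) for k in range(len(C_i[0]))]
    x_vi = pooled + list(demographics) if demographics is not None else pooled
    return x_vi, beta
```

The per-code weights returned here play the role of the code-visit correlation used by the interpretation path.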
In addition, we propose a method to interpret TCoN, based on the correlation values among codes, visits, and predictions.

Interpretation path
The interpretation path is based on the correlations $R$, which contain two kinds of correlations. The code-code correlation is obtained from $\alpha$ of CS-attention: $\alpha_{ij}$ is the effect of code $j$ on code $i$, and a large $\alpha_{ij}$ means that code $j$ could be the cause, complication, or early symptom of code $i$. The code-visit correlation is obtained from $\tilde{\beta}$ of the attention connection: a larger $\tilde{\beta}$ means a closer relation.
The interpretation path is a code sequence obtained by a reverse lookup starting from the prediction result. For a prediction $P$, let the last visit be $v_n$. In $v_n$, we find the code $c_{ni}$ that contributes the most to $v_n$ according to $\tilde{\beta}$. For $c_{ni}$, we find the closest code $c_{(n-1)i}$ in visit $v_{n-1}$ according to the largest $\alpha_{*c_{ni}}$. Similarly, we find $c_{(n-2)i}, c_{(n-3)i}, \ldots, c_{1i}$. This yields a path $c_{1i} \rightarrow \cdots \rightarrow c_{ni} \rightarrow P$, which reads: disease $c_{1i}$ most likely leads to $c_{2i}$, then $c_{2i}$ most likely leads to $c_{3i}$, …, $c_{(n-1)i}$ most likely leads to $c_{ni}$, and finally $c_{ni}$ most likely causes $P$.
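The reverse lookup above can be sketched as follows. The sketch assumes the correlations are already available as plain dictionaries (`beta_last` for $\tilde{\beta}$ of the last visit and `alpha[(a, b)]` for $\alpha_{ab}$), follows the one-code-per-visit procedure described above, and reads "the largest $\alpha_{*c}$" as maximizing $\alpha_{kc}$ over the codes $k$ of the earlier visit. The function name is hypothetical.

```python
def interpretation_path(visits, beta_last, alpha):
    """Reverse lookup of an interpretation path (sketch).

    visits:    list of code lists v_1 ... v_n in chronological order
    beta_last: dict code -> code-visit correlation (beta~) for the last visit v_n
    alpha:     dict (a, b) -> code-code correlation alpha_ab; missing pairs count as 0
    Returns [c_1, ..., c_n]; the path reads c_1 -> ... -> c_n -> prediction.
    """
    current = max(visits[-1], key=lambda c: beta_last[c])  # code contributing most to v_n
    path = [current]
    for visit in reversed(visits[:-1]):
        # in the earlier visit, pick the code k with the largest alpha_{k, current}
        current = max(visit, key=lambda c: alpha.get((c, current), 0.0))
        path.append(current)
    return list(reversed(path))
```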
Finally, we apply a training method that enables TCoN to handle imbalanced data [37,38].

Pre-train and Fine-tune
In the pre-train process, we train an auto-encoder network $f_{ae}$ by minimizing its loss (Eq. 9) on an unsupervised representation learning task. In the fine-tune process, we use the parameters of the encoder layer as the initial parameters of TCoN and train it with the prediction objective in Eq. (10). In TCoN, the input layer is given by Eq. (8), the skip connection by Eq. (12), layer normalization [29] by Eq. (13), and the feed-forward layer by Eq. (14).
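Assuming TCoN follows the standard Transformer arrangement [32] for Eqs. (12) to (14), a sublayer's output is added to its input through a skip connection and then layer-normalized [29], and the feed-forward layer is a two-layer network with a ReLU. The sketch below shows that arrangement; the default gain of 1 and bias of 0 for layer normalization, and all function names, are assumptions.

```python
import math


def layer_norm(x, gamma=None, beta=None, eps=1e-6):
    """Normalize a vector to zero mean and unit variance, then scale and shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    normed = [(v - mean) / math.sqrt(var + eps) for v in x]
    gamma = gamma or [1.0] * len(x)
    beta = beta or [0.0] * len(x)
    return [g * n + b for g, n, b in zip(gamma, normed, beta)]


def feed_forward(x, W1, b1, W2, b2):
    """Two-layer feed-forward network: W2 @ relu(W1 @ x + b1) + b2."""
    hidden = [max(0.0, sum(w * v for w, v in zip(row, x)) + b) for row, b in zip(W1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(W2, b2)]


def sublayer_block(x, sublayer):
    """Skip connection followed by layer normalization: LayerNorm(x + sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])
```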

Complexity analysis
The self-attention-based algorithm is parallel, whereas the RNN-based algorithm is serial [32]. TCoN contains both structures, connected in series. Thus, the complexity of TCoN is $O(n^2 \cdot d) + O(n \cdot d^2)$, where $d$ is the representation dimension and $n$ is the sequence length. $O(n^2 \cdot d)$ is the complexity of CS-attention, where $n^2$ accounts for the operations on every pair of inputs. $O(n \cdot d^2)$ is the complexity of T-GRU, where $d^2$ accounts for the matrix operations at each sequential step. In our data, the dimensionality $d$ is larger than the sequence length $n$, so the complexity of TCoN is $O(n \cdot d^2)$.
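A worked comparison of the two terms makes the final step explicit. Which term dominates depends only on the ratio $n/d$; the values $n = 3$ and $d = 128$ below are illustrative, not the paper's settings.

$$
O(n^2 d) + O(n d^2) = O\big(nd\,(n + d)\big), \qquad \frac{n^2 d}{n d^2} = \frac{n}{d}
$$

$$
n = 3,\; d = 128: \qquad n^2 d = 1{,}152, \qquad n d^2 = 49{,}152
$$

When $d > n$ the ratio is below 1, so the T-GRU term $O(n \cdot d^2)$ dominates; when $n > d$ the CS-attention term $O(n^2 \cdot d)$ would dominate instead.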

Experimental setup
For the data, we right-align the time series and use padding and masking to make them equal in length. Each code is represented by a one-hot vector with 4,893 dimensions (the number of ICD-9 codes). The data are split into training, validation, and test sets in a 0.75:0.10:0.15 ratio.
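A minimal sketch of right-aligned padding with a mask, the one-hot code encoding, and a 0.75:0.10:0.15 split is given below. Truncating sequences longer than `max_len` to their most recent visits, the `None` padding value, the shuffled index split with a fixed seed, and all function names are assumptions for illustration.

```python
import random


def right_align(sequences, max_len, pad_value=None):
    """Left-pad each visit sequence so the most recent visit sits in the last slot.

    Sequences longer than max_len keep only their last max_len visits.
    Returns (padded, mask); mask[i][t] is 1 for a real visit and 0 for padding.
    """
    padded, mask = [], []
    for seq in sequences:
        seq = list(seq)[-max_len:]
        n_pad = max_len - len(seq)
        padded.append([pad_value] * n_pad + seq)
        mask.append([0] * n_pad + [1] * len(seq))
    return padded, mask


def one_hot(code, vocab):
    """One-hot vector for a single ICD-9 code; vocab maps code -> index (4,893 codes)."""
    vec = [0] * len(vocab)
    vec[vocab[code]] = 1
    return vec


def split_indices(n, ratios=(0.75, 0.10, 0.15), seed=0):
    """Shuffle n sample indices and cut them into train/validation/test parts."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train, n_val = int(n * ratios[0]), int(n * ratios[1])
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]
```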
For the model, we use 2 local heads and 2 global heads. We choose the logarithmic time decay with $\alpha = 1$ and a year as the decay unit. We apply the Adam optimizer [39] with learning rate $\alpha = 0.001$, $\beta_1 = 0.9$, and $\beta_2 = 0.999$. We use the learning-rate decay $\alpha_{\text{current}} = \alpha_{\text{initial}} \cdot \gamma^{\text{global step}/\text{decay steps}}$ with decay rate $\gamma = 0.98$ and decay steps $= 2000$ [40]. Before the prediction task, we carry out the pre-train step and apply early stopping with 5 epochs. We use fivefold cross-validation. The code implementation is publicly available at https://github.com/SCXsunchenxi/MTGRU.

Baselines
• Time-aware methods (RNN-based methods)
• GRU [14]. It uses GRU to embed visits and make the final prediction.
• T-LSTM [8]. It uses elapsed-time weights to adjust the previous memory in the LSTM.
• Co-occurrence-aware methods (Word2Vec-based methods)
• Med2Vec [4]. It applies the skip-gram model and a multi-layer perceptron to obtain representations of codes and visits.
• Dipole [1]. It uses a BRNN along with three attention mechanisms to measure the relations of different visits for the final prediction.
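The learning-rate schedule in the experimental setup can be written as a one-line function. This sketch uses the continuous exponent $\text{global step}/\text{decay steps}$ exactly as in the formula; a staircase variant would use integer division instead, and which variant was used is not stated. The function name is hypothetical.

```python
def decayed_learning_rate(global_step, initial_lr=0.001, decay_rate=0.98, decay_steps=2000):
    """Exponential decay: lr = initial_lr * decay_rate ** (global_step / decay_steps)."""
    return initial_lr * decay_rate ** (global_step / decay_steps)
```

For example, after 2000 steps the rate has decayed from 0.001 to 0.00098.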

Prediction results
TCoN predicts more accurately than all baselines. The results of the binary classification tasks (mortality, readmission, sepsis, and heart failure) and the multi-classification task (next diagnoses) are shown in Table 2(a, b). The baselines may not match EHR characteristics and model the data features only partially. For example, T-LSTM performs worst because it is not suited to short visit sequences such as those in MIMIC-III. TCoN also performs well on imbalanced datasets. In the binary classification tasks, all datasets are imbalanced, especially the sepsis dataset (6.16%), and the results show that the more imbalanced the data, the greater TCoN's advantage over the baselines.
TCoN can accurately predict multiple diagnoses in the next admission. In the multi-classification task, we evaluate the methods with k = 5, 15, 25, and 35. As shown in Table 2b, the accuracies of all methods decrease as k increases, but the advantage of our approach remains clear.

Model parameters experiments
We vary the dimension of the representation vector in the hidden layers. The results in Fig. 3a show that TCoN performs better than the other methods under all dimensions.
Then, we set different numbers of heads for TCoN. Figure 3b shows that two heads is the key turning point.

Case study of interpretation path
We choose patient 32,790 in MIMIC-III (a white man who had 3 admission records and died at 80) to illustrate how TCoN produces the interpretation path. Figure 4a is the heat map of $\alpha$ for the death prediction. The diagnosis 'hypoxemia' contributes the most to the last admission because the norm of its weighted vector is the largest. For 'hypoxemia', the closest diagnosis is 'pulmonary collapse', with the largest $\alpha_{*i} = 0.892$. For 'pulmonary collapse', the closest diagnosis is 'unspecified pleural effusion', with the largest $\alpha_{*i} = 0.803$. For 'unspecified pleural effusion', the closest diagnosis is 'unspecified sleep apnea', with the largest $\alpha_{*i} = 0.782$. This yields the interpretation path 'unspecified sleep apnea -> unspecified pleural effusion -> pulmonary collapse -> hypoxemia -> death', shown in Fig. 4b. Figure 4c shows interpretation paths for sepsis prediction and heart failure prediction; each path summarizes the results using the most frequent diagnoses. Thus, we find sepsis-related pre-diagnoses/symptoms, such as 'Fever', 'Chills', 'Immunity disorders', 'Anemia', and 'Coma', and heart failure-related pre-diagnoses/symptoms, such as 'Ventricular fibrillation', 'Myocarditis', 'Coronary atherosclerosis', and 'Hypertension'.

Discussion
In recent years, deep learning (DL) technology has shown superior performance in medical applications [41][42][43][44], such as medical image recognition [45] and [46]. Many methods have also achieved good performance in specific disease prediction, such as Alzheimer's disease [47], sepsis [48], and heart disease [49,50]. However, most of them pursue task accuracy but ignore interpretability. DL-based approaches are black-box models, which are not easy to understand for non-professionals, especially doctors without an artificial intelligence background. Thus, explainable DL methods are needed. This study addresses this problem with a solution, the interpretation path, that makes the predictions explainable. In EHRs, a patient's records are irregular in time because of the unpredictability of diseases and inevitable data loss. The current disease could be more closely related to a disease a week ago than to one a year ago [8,9]. Thus, a time perception mechanism is needed. This study addresses this issue with a time gate that explicitly learns the irregular time information through the time decay function.
The experiments show that using two kinds of heads for inter-visit and intra-visit relations is necessary. The difference between these two relations is not just the time interval but also the pathology. Code relations within the same visit are more likely to be complications, whereas those across different visits are more likely to be causations and continuities. For example, in our experiments, the relation between 'diabetes' and 'cellulitis and abscess of legs' in one visit is more likely a short-term complication, whereas the relation between 'diabetes' and 'long-term use of insulin' in two different visits is more likely causation. Thus, for each patient, we can build a disease association graph. The weight of an edge between two diagnoses in the same admission represents the adjoint coefficient, and the weight of an edge between two diagnoses in different admissions represents the causal coefficient. Figure 5 shows the diagnosis graph of patient 32,790. The correlation is not symmetric, that is, $\alpha_{ij} \neq \alpha_{ji}$ in general: $\alpha_{ij} = \frac{\#\text{ of } i\text{-}j \text{ occurrences}}{\#\text{ of } i \text{ occurrences}}$ and $\alpha_{ji} = \frac{\#\text{ of } i\text{-}j \text{ occurrences}}{\#\text{ of } j \text{ occurrences}}$ have different denominators. For example, let codes $i$, $j$, $k$ represent the diagnoses 'malaria', 'fever', and 'periodic cold fever', respectively. In our experiment, $i$ is mostly accompanied by $j$, as $\alpha_{ij} = 0.762$. But $j$ is rarely accompanied by $i$, as $\alpha_{ji} = 0.023$, whereas $k$ is mostly accompanied by $i$, with $\alpha_{ki} = 0.701$. Comparing $\alpha_{ji}$ and $\alpha_{ki}$ shows that 'periodic cold fever' is a better explanation for 'malaria' than 'fever'. According to [51], 'periodic cold fever' is a special clinical manifestation of 'malaria', and very few other diseases have this symptom. This illustrates that our interpretation method can explain the results by reflecting the relations (such as complication, causation, and continuity) between diagnoses, and that $\alpha_{*i}$ is a more important criterion than $\alpha_{i*}$ for finding the code that co-occurs most with $i$.
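The count-based definition of $\alpha_{ij}$ above can be computed directly, as sketched below. The sketch assumes that an "i-j occurrence" means a visit containing both codes, and the toy visits are illustrative placeholders rather than data from the experiments. In TCoN itself the correlations come from CS-attention, so this only illustrates the asymmetric ratio, not the learned weights.

```python
from collections import Counter
from itertools import permutations


def cooccurrence_ratios(visits):
    """alpha[(i, j)] = #(visits containing both i and j) / #(visits containing i)."""
    code_count = Counter()
    pair_count = Counter()
    for visit in visits:
        codes = set(visit)
        code_count.update(codes)
        pair_count.update(permutations(codes, 2))  # ordered pairs, both directions
    return {(i, j): n_ij / code_count[i] for (i, j), n_ij in pair_count.items()}


# Toy example with placeholder codes i, j, k.
toy_visits = [["i", "j"], ["i", "j"], ["j"], ["j"], ["k", "i"]]
alpha = cooccurrence_ratios(toy_visits)
# alpha[("i", "j")] == 2/3 and alpha[("j", "i")] == 2/4: same numerator, different denominators.
# alpha[("k", "i")] == 1.0 > alpha[("j", "i")], so k explains i better than j does.
```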
In medical applications, data are usually imbalanced. Patients in a normal state form the majority, while disease records may form only a small sample, yet this small sample matters most for disease prediction. Thus, a DL model should be robust on imbalanced datasets. In this paper, our pre-train and fine-tune framework helps achieve this.
There is still room for improvement. The current modeling method is based on EHR data alone. Integrating prior information would make the data relation modeling and medical prediction more accurate and reasonable; one available method is knowledge graph embedding based on ICD codes. Besides, more EHR data, such as doctor notes, medications, and laboratory tests, can be used for better performance. Future work will focus on these aspects.

Conclusion
Data-driven medical prediction based on interpretable deep learning is essential for healthcare management. In this paper, we propose an interpretable Time-aware and Co-occurrence-aware Network (TCoN) for data modeling and medical prediction. It perceives hierarchical data structures with both the time relation and the co-occurrence relation, gives an interpretation path to explain each prediction, and builds a diagnosis graph for every patient. The experiments show that TCoN outperforms the state-of-the-art methods.