 Research
 Open Access
Incorporating medical code descriptions for diagnosis prediction in healthcare
BMC Medical Informatics and Decision Making volume 19, Article number: 267 (2019)
Abstract
Background
Diagnosis prediction aims to predict the future health status of patients from their historical electronic health records (EHR), which is an important yet challenging task in healthcare informatics. Existing diagnosis prediction approaches mainly employ recurrent neural networks (RNN) with attention mechanisms to make predictions. However, these approaches ignore the importance of code descriptions, i.e., the medical definitions of diagnosis codes. We believe that taking diagnosis code descriptions into account can help state-of-the-art models not only to learn meaningful code representations, but also to improve predictive performance, especially when the EHR data are insufficient.
Methods
We propose a simple but general diagnosis prediction framework, which includes two basic components: diagnosis code embedding and predictive model. To learn interpretable code embeddings, we apply convolutional neural networks (CNN) to model the medical descriptions of diagnosis codes extracted from online medical websites. The learned medical embedding matrix is used to embed the input visits into vector representations, which are fed into the predictive models. Any existing diagnosis prediction approach (referred to as the base model) can be cast into the proposed framework as the predictive model (yielding what we call the enhanced model).
Results
We conduct experiments on two real medical datasets: the MIMIC-III dataset and the Heart Failure claim dataset. Experimental results show that the enhanced diagnosis prediction approaches significantly improve prediction performance. Moreover, we validate the effectiveness of the proposed framework with insufficient EHR data. Finally, we visualize the learned medical code embeddings to show the interpretability of the proposed framework.
Conclusions
Given the historical visit records of a patient, the proposed framework is able to predict the next visit information by incorporating medical code descriptions.
Background
The immense accumulation of electronic health records (EHR) makes it possible to directly predict patients’ future health status by analyzing their historical visit records [1–4]. Diagnosis prediction, which aims to predict the diagnosis information of patients in subsequent visits, attracts considerable attention from both healthcare providers and researchers. The diagnosis prediction task poses two key challenges: (1) designing an accurate and robust predictive model to handle temporal, high-dimensional and noisy EHR data; and (2) reasonably interpreting the advantages and effectiveness of the proposed models to both doctors and patients.
To address these challenges, many recurrent neural network (RNN)-based models [2–4] have been proposed. RETAIN [4] uses two recurrent neural networks with attention mechanisms to model reverse time-ordered EHR sequences. By employing a bidirectional recurrent neural network (BRNN), Dipole [2] enhances prediction accuracy with different attention mechanisms. To guarantee predictive performance, training these models usually requires a large amount of EHR data. However, EHR data commonly contain medical codes of rare diseases, and these diagnosis codes appear only infrequently. GRAM [3] was proposed to overcome this issue. GRAM learns medical code representations by exploiting medical ontology information and a graph-based attention mechanism. For rare medical codes, GRAM alleviates the difficulty of learning their embeddings by drawing on the embeddings of their ancestors, which helps maintain predictive performance. However, the performance of GRAM depends heavily on the choice of medical ontology. Thus, without specific input constraints, learning robust embeddings for medical codes remains the major challenge for accurate diagnosis prediction.
To resolve this challenge, we consider the “nature” of diagnosis codes, i.e., their medical descriptions. Each diagnosis code has a formal description, which can be easily obtained from the Internet, for example from Wikipedia or online medical websites. For example, the description of diagnosis code “428.32” is “Chronic diastolic heart failure” (http://www.icd9data.com/2015/Volume1/390-459/420-429/428/428.32.htm), and “Rheumatic heart failure (congestive)” is the description of diagnosis code “398.91” (http://www.icd9data.com/2015/Volume1/390-459/393-398/398/398.91.htm). If the medical meanings of diagnosis codes are ignored, these two codes are treated as two independent diseases in the EHR dataset. However, both describe the same condition, i.e., “heart failure”. Thus, we believe that incorporating the descriptions of diagnosis codes into the prediction should help predictive models improve prediction accuracy and provide interpretable representations of medical codes, especially when the EHR data are insufficient.
Another benefit of incorporating diagnosis code descriptions is that it enables us to design a general diagnosis prediction framework. All existing diagnosis prediction approaches take the same input, i.e., a sequence of time-ordered visits, each consisting of a set of diagnosis codes. Thus, all existing approaches, including but not limited to RETAIN, Dipole and GRAM, can be extended to incorporate the descriptions of diagnosis codes to further improve their predictive performance.
In this paper, we propose a novel framework for the diagnosis prediction task. Notably, all state-of-the-art diagnosis prediction approaches (referred to as base models) can be cast into the proposed framework; the base models enhanced by the framework are called enhanced models. Specifically, the proposed framework consists of two components: diagnosis code embedding and predictive model. The diagnosis code embedding component learns medical representations of diagnosis codes from their descriptions. In particular, for each word in a description, we obtain the pretrained vector representation from fastText [5]. The concatenation of these word vectors for each diagnosis code description is then fed into a convolutional neural network (CNN) to generate the medical embeddings. Based on the learned medical embeddings of diagnosis codes, the predictive model component makes predictions. It first embeds the input visit information into a visit-level vector representation with the code embeddings, and then feeds this vector into the predictive model, which can be any existing diagnosis prediction approach.
We use two real medical datasets to illustrate the superior performance of the proposed framework on the diagnosis prediction task compared with several state-of-the-art approaches. Quantitative analysis is also conducted to validate the effectiveness of the proposed approaches with insufficient EHR data. Finally, we qualitatively analyze the interpretability of the enhanced approaches by visualizing the learned medical code embeddings against the embeddings learned by existing approaches. In summary, this paper makes the following contributions:
We recognize the importance of learning diagnosis code embeddings from their descriptions, which can be extracted directly from the Internet.
We propose a simple but general and effective diagnosis prediction framework, which learns representations of diagnosis codes directly from their descriptions.
All state-of-the-art approaches can be cast into the proposed framework to improve the performance of diagnosis prediction.
Experimental results on two medical datasets validate the effectiveness of the proposed framework and the interpretability of the prediction results.
Related Work
In this section, we briefly survey work related to the diagnosis prediction task. We first give a general introduction to mining healthcare-related data with deep learning techniques, and then survey work on diagnosis prediction.
Deep Learning for EHR
Several machine learning approaches have been proposed to mine medical knowledge from EHR data [1, 6–10]. Among them, deep learning-based models have achieved better performance than traditional machine learning approaches [11–13]. To detect characteristic patterns of physiology in clinical time series data, stacked denoising autoencoders (SDA) are used in [14]. Convolutional neural networks (CNN) have been applied to predict unplanned readmission [15], sleep stages [16], diseases [17, 18] and risk [19–21] with EHR data. To capture the temporal characteristics of healthcare-related data, recurrent neural networks (RNN) are widely used for modeling disease progression [22, 23], mining time series healthcare data with missing values [24, 25], and diagnosis classification [26] and prediction [2–4, 27].
Diagnosis Prediction
Diagnosis prediction is one of the core research tasks in EHR data mining; it aims to predict future visit information from historical visit records. Med2Vec [28] is the first unsupervised method to learn interpretable embeddings of medical codes, but it ignores long-term dependencies of medical codes across visits. RETAIN [4] is the first interpretable model to mathematically calculate the contribution of each medical code to the current prediction, employing a reverse time attention mechanism in an RNN for a binary prediction task. Dipole [2] is the first work to adopt bidirectional recurrent neural networks (BRNN) and different attention mechanisms to improve prediction accuracy. GRAM [3] is the first work to apply a graph-based attention mechanism to a given medical ontology to learn robust medical code embeddings even when training data are lacking, and it uses an RNN to model patient visits. KAME [29], which builds upon GRAM, uses high-level knowledge to improve predictive performance.
Different from all the aforementioned diagnosis prediction models, the proposed diagnosis prediction framework incorporates the descriptions of diagnosis codes to learn embeddings, which improves prediction accuracy and provides interpretable prediction results compared with state-of-the-art approaches.
Methods
In this section, we first define the notation used in the diagnosis prediction task, then introduce preliminary concepts, and finally describe the details of the proposed framework.
Notations
We denote all the unique diagnosis codes in the EHR data as a code set \(\mathcal{C} = \{c_{1}, c_{2}, \cdots, c_{|\mathcal{C}|}\}\), where \(|\mathcal{C}|\) is the number of diagnosis codes. Let \(|\mathcal{P}|\) denote the number of patients in the EHR data. For the p-th patient, who has T visit records, the visit information of this patient can be represented by a sequence of visits \(\mathcal{V}^{(p)} = \left\{V_{1}^{(p)}, V_{2}^{(p)}, \cdots, V_{T}^{(p)}\right\}\). Each visit \(V_{t}^{(p)}\) consists of multiple diagnosis codes, i.e., \(V_{t}^{(p)} \subseteq \mathcal{C}\), and is denoted by a binary vector \(\mathbf{x}_{t}^{(p)} \in \{0, 1\}^{|\mathcal{C}|}\). The i-th element of \(\mathbf{x}_{t}^{(p)}\) is 1 if \(V_{t}^{(p)}\) contains the diagnosis code c_{i}, and 0 otherwise. For simplicity, we drop the superscript (p) when it is unambiguous.
Each diagnosis code c_{i} has a formal medical description, which can be obtained from Wikipedia (https://en.wikipedia.org/wiki/List_of_ICD-9_codes) or ICD9Data.com (http://www.icd9data.com/). We denote all the unique words used to describe the diagnosis codes as \(\mathcal{W} = \{w_{1}, w_{2}, \cdots, w_{|\mathcal{W}|}\}\), where \(|\mathcal{W}|\) is the number of unique words, and denote the description of c_{i} as \(c_{i}^{\prime} \subseteq \mathcal{W}\).
With the aforementioned notation, the inputs of the proposed framework are the set of code descriptions \(\left\{c_{1}^{\prime}, c_{2}^{\prime}, \cdots, c_{|\mathcal{C}|}^{\prime}\right\}\) and the set of time-ordered sequences of patient visits \(\left\{\mathbf{x}_{1}^{(p)}, \mathbf{x}_{2}^{(p)}, \cdots, \mathbf{x}_{T-1}^{(p)}\right\}_{p=1}^{|\mathcal{P}|}\). For each timestep t, we aim to predict the information of the (t+1)th visit. Thus, the outputs are \(\left\{\mathbf{x}_{2}^{(p)}, \mathbf{x}_{3}^{(p)}, \cdots, \mathbf{x}_{T}^{(p)}\right\}_{p=1}^{|\mathcal{P}|}\).
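As a minimal sketch of this input/output setup, the snippet below encodes each visit as a multi-hot vector x_t and pairs the inputs x_1, ..., x_{T−1} with the targets x_2, ..., x_T for a single patient. The helper names (`multi_hot`, `make_pairs`) and the toy patient are hypothetical and only illustrate the notation.

```python
from typing import Dict, List, Sequence, Tuple


def multi_hot(visit: Sequence[str], code_index: Dict[str, int]) -> List[int]:
    """Encode one visit V_t (a set of diagnosis codes) as x_t in {0, 1}^|C|."""
    x = [0] * len(code_index)
    for code in visit:
        x[code_index[code]] = 1
    return x


def make_pairs(
    visits: Sequence[Sequence[str]], code_index: Dict[str, int]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return inputs x_1..x_{T-1} and targets x_2..x_T for one patient."""
    xs = [multi_hot(v, code_index) for v in visits]
    return xs[:-1], xs[1:]


# Hypothetical toy code set and a patient with T = 3 visits.
codes = ["428.32", "398.91", "401.9"]
code_index = {c: i for i, c in enumerate(codes)}
patient = [["401.9"], ["428.32", "401.9"], ["398.91"]]

inputs, targets = make_pairs(patient, code_index)
print(inputs)   # [[0, 0, 1], [1, 0, 1]]
print(targets)  # [[1, 0, 1], [0, 1, 0]]
```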
Preliminaries
In this subsection, we first introduce the techniques commonly used for modeling patients’ visits, and then list the state-of-the-art diagnosis prediction approaches.
Fully Connected Layer
Deep learning-based models are commonly used to model patients’ visits. Among existing models, the fully connected (FC) layer is the simplest approach, which is defined as follows:
where \(\mathbf{v}_{t} \in \mathbb{R}^{d}\) is the input data, d is the input dimensionality, and \(\mathbf{W}_{c} \in \mathbb{R}^{|\mathcal{C}| \times d}\) and \(\mathbf{b}_{c} \in \mathbb{R}^{|\mathcal{C}|}\) are the learnable parameters.
Recurrent Neural Networks
Recurrent neural networks (RNNs) have been shown to be effective in modeling healthcare data [2–4, 30]. Note that we use “RNN” to denote any recurrent neural network variant, such as Long Short-Term Memory (LSTM) [31], T-LSTM [32] and the Gated Recurrent Unit (GRU) [33]. In this paper, GRU is used to adaptively capture dependencies among patient visits. A GRU has two gates: the reset gate r and the update gate z. The reset gate r computes its state from both the new input and the previous memory, and allows the hidden layer to drop irrelevant information. The update gate z controls how much information from the previous hidden state is kept. The GRU can be formulated as follows:
\[
\begin{aligned}
\mathbf{z}_{t} &= \sigma(\mathbf{W}_{z} \mathbf{v}_{t} + \mathbf{U}_{z} \mathbf{h}_{t-1} + \mathbf{b}_{z}),\\
\mathbf{r}_{t} &= \sigma(\mathbf{W}_{r} \mathbf{v}_{t} + \mathbf{U}_{r} \mathbf{h}_{t-1} + \mathbf{b}_{r}),\\
\tilde{\mathbf{h}}_{t} &= \tanh(\mathbf{W}_{h} \mathbf{v}_{t} + \mathbf{U}_{h} (\mathbf{r}_{t} \circ \mathbf{h}_{t-1}) + \mathbf{b}_{h}),\\
\mathbf{h}_{t} &= \mathbf{z}_{t} \circ \mathbf{h}_{t-1} + (1 - \mathbf{z}_{t}) \circ \tilde{\mathbf{h}}_{t},
\end{aligned} \tag{2}
\]
where \(\mathbf{v}_{t} \in \mathbb{R}^{d}\) is the input at time t, \(\mathbf{z}_{t} \in \mathbb{R}^{g}\) is the update gate at time t, g is the dimensionality of hidden states, σ(·) is the sigmoid activation function, \(\mathbf{h}_{t} \in \mathbb{R}^{g}\) is the hidden state, \(\mathbf{r}_{t} \in \mathbb{R}^{g}\) is the reset gate at time t, \(\tilde{\mathbf{h}}_{t} \in \mathbb{R}^{g}\) represents the intermediate memory, and ∘ denotes element-wise multiplication. Matrices \(\mathbf{W}_{z} \in \mathbb{R}^{g \times d}, \mathbf{W}_{r} \in \mathbb{R}^{g \times d}, \mathbf{W}_{h} \in \mathbb{R}^{g \times d}, \mathbf{U}_{z} \in \mathbb{R}^{g \times g}, \mathbf{U}_{r} \in \mathbb{R}^{g \times g}, \mathbf{U}_{h} \in \mathbb{R}^{g \times g}\) and vectors \(\mathbf{b}_{z} \in \mathbb{R}^{g}, \mathbf{b}_{r} \in \mathbb{R}^{g}, \mathbf{b}_{h} \in \mathbb{R}^{g}\) are parameters to be learned. For simplicity, the GRU can be represented by
\[ \mathbf{h}_{t} = \mathrm{GRU}(\mathbf{v}_{t}, \mathbf{h}_{t-1}; \Omega), \tag{3} \]
where Ω denotes all the parameters of the GRU.
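The GRU described above can be sketched in plain Python as follows. This is an illustrative sketch, not the paper's implementation. It assumes σ is the logistic sigmoid and follows the common convention that the update gate z_t weights the previous hidden state, h_t = z_t ∘ h_{t−1} + (1 − z_t) ∘ h̃_t. The function names and parameter-dictionary keys are hypothetical.

```python
import math
from typing import Dict, List

Vec = List[float]
Mat = List[List[float]]


def matvec(M: Mat, v: Vec) -> Vec:
    """Matrix-vector product for row-major nested lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gru_step(v_t: Vec, h_prev: Vec, p: Dict[str, list]) -> Vec:
    """One step h_t = GRU(v_t, h_{t-1}); p holds Wz, Wr, Wh (g x d), Uz, Ur, Uh (g x g), bz, br, bh (g)."""
    g = len(h_prev)
    wz, uz = matvec(p["Wz"], v_t), matvec(p["Uz"], h_prev)
    wr, ur = matvec(p["Wr"], v_t), matvec(p["Ur"], h_prev)
    z = [sigmoid(wz[j] + uz[j] + p["bz"][j]) for j in range(g)]  # update gate
    r = [sigmoid(wr[j] + ur[j] + p["br"][j]) for j in range(g)]  # reset gate
    # The candidate memory sees the reset-gated previous state r_t ∘ h_{t-1}.
    wh = matvec(p["Wh"], v_t)
    uh = matvec(p["Uh"], [r[j] * h_prev[j] for j in range(g)])
    h_tilde = [math.tanh(wh[j] + uh[j] + p["bh"][j]) for j in range(g)]
    # z_t decides how much of h_{t-1} is kept versus replaced by the candidate.
    return [z[j] * h_prev[j] + (1.0 - z[j]) * h_tilde[j] for j in range(g)]


def zeros(rows: int, cols: int) -> Mat:
    return [[0.0] * cols for _ in range(rows)]


# Toy check: with all-zero parameters, z_t = 0.5 and the candidate is tanh(0) = 0,
# so the new state is exactly half of the previous one.
d, g = 2, 2
params = {
    "Wz": zeros(g, d), "Wr": zeros(g, d), "Wh": zeros(g, d),
    "Uz": zeros(g, g), "Ur": zeros(g, g), "Uh": zeros(g, g),
    "bz": [0.0] * g, "br": [0.0] * g, "bh": [0.0] * g,
}
h = gru_step([1.0, 1.0], [1.0, -2.0], params)
print(h)  # [0.5, -1.0]
```

To process a whole visit sequence, the step is applied repeatedly, starting from a zero state and feeding each v_t together with the previous output.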
Attention Mechanisms
Attention mechanisms aim to distinguish the importance of different inputs, and attention-based neural networks have been successfully used in the diagnosis prediction task, including location-based attention [2, 4], general attention [2], concatenation-based attention [2], and graph-based attention [3]. In the following, we introduce two commonly used attention mechanisms: location-based and graph-based attention.
∙Location-based Attention. The location-based attention mechanism [2, 4] calculates an attention score for each visit that depends solely on the corresponding hidden state \(\mathbf{h}_{i} \in \mathbb{R}^{g}\) (1≤i≤t):
\[ \alpha_{i} = \mathbf{W}_{\alpha}^{\top} \mathbf{h}_{i} + b_{\alpha}, \tag{4} \]
where \(\mathbf{W}_{\alpha} \in \mathbb{R}^{g}\) and \(b_{\alpha} \in \mathbb{R}\) are the parameters to be learned. According to Eq. (4), we obtain an attention weight vector α=[α_{1},α_{2},⋯,α_{t}] for the t visits, which is then normalized with the softmax function. Finally, the context vector c_{t} is computed from the normalized attention weights and the hidden states h_{1} to h_{t}:
\[ \mathbf{c}_{t} = \sum_{i=1}^{t} \alpha_{i} \mathbf{h}_{i}. \tag{5} \]
The context vector c_{t} is thus the weighted sum of all the visit information from time 1 to t.
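A small sketch of location-based attention follows. It assumes the score in Eq. (4) has the linear form W_αᵀ h_i + b_α implied by the parameter shapes (W_α ∈ R^g, b_α ∈ R), followed by softmax normalization and the weighted sum of Eq. (5). The function name and toy hidden states are hypothetical.

```python
import math
from typing import List, Tuple

Vec = List[float]


def location_attention(hs: List[Vec], w_alpha: Vec, b_alpha: float) -> Tuple[Vec, Vec]:
    """Score each hidden state, softmax-normalize the scores, and return (alpha, c_t)."""
    scores = [sum(w * x for w, x in zip(w_alpha, h)) + b_alpha for h in hs]
    m = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Context vector: attention-weighted sum of h_1..h_t.
    c_t = [sum(a * h[j] for a, h in zip(alpha, hs)) for j in range(len(hs[0]))]
    return alpha, c_t


hs = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical h_1, h_2 with g = 2
alpha, c_t = location_attention(hs, [0.0, 0.0], 0.0)
print(alpha, c_t)  # [0.5, 0.5] [0.5, 0.5]
```

With zero parameters every visit receives the same weight. Nonzero W_α shifts weight toward visits whose hidden states align with it.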
∙Graph-based Attention. Graph-based attention [3] was proposed to learn robust representations of diagnosis codes even when data are limited. It explicitly uses the parent-child relationships among diagnosis codes in a given medical ontology to learn code embeddings.
Given a medical ontology \(\mathcal{G}\), which is a directed acyclic graph (DAG), each leaf node of \(\mathcal{G}\) is a diagnosis code c_{i}, and each non-leaf node belongs to the set \(\hat{\mathcal{C}}\). Each leaf node has a basic learnable embedding vector \(\mathbf{e}_{i} \in \mathbb{R}^{d}\) (\(1 \leq i \leq |\mathcal{C}|\)), while \(\mathbf{e}_{|\mathcal{C}| + 1}, \cdots, \mathbf{e}_{|\mathcal{C}| + |\hat{\mathcal{C}}|}\) represent the basic embeddings of the internal nodes \(c_{|\mathcal{C}| + 1}, \cdots, c_{|\mathcal{C}| + |\hat{\mathcal{C}}|}\). Let \(\mathcal{A}(i)\) be the set containing c_{i} and its ancestors. The final embedding of diagnosis code c_{i}, denoted by \(\mathbf{g}_{i} \in \mathbb{R}^{d}\), is obtained as follows:
\[ \mathbf{g}_{i} = \sum_{j \in \mathcal{A}(i)} \alpha_{ij} \mathbf{e}_{j}, \tag{6} \]
where
\[ \alpha_{ij} = \frac{\exp(\theta(\mathbf{e}_{i}, \mathbf{e}_{j}))}{\sum_{k \in \mathcal{A}(i)} \exp(\theta(\mathbf{e}_{i}, \mathbf{e}_{k}))}. \tag{7} \]
θ(·,·) is a scalar-valued function defined as
\[ \theta(\mathbf{e}_{i}, \mathbf{e}_{j}) = \mathbf{u}_{a}^{\top} \tanh\left(\mathbf{W}_{a} \begin{bmatrix} \mathbf{e}_{i} \\ \mathbf{e}_{j} \end{bmatrix} + \mathbf{b}_{a}\right), \tag{8} \]
where \(\mathbf{u}_{a} \in \mathbb{R}^{l}, \mathbf{W}_{a} \in \mathbb{R}^{l \times 2d}\) and \(\mathbf{b}_{a} \in \mathbb{R}^{l}\) are parameters to be learned. Finally, the graph-based attention mechanism generates the medical code embeddings \(\mathbf{G} = \{\mathbf{g}_{1}, \mathbf{g}_{2}, \cdots, \mathbf{g}_{|\mathcal{C}|}\} \in \mathbb{R}^{d \times |\mathcal{C}|}\).
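The graph-based attention mechanism can be sketched as follows. We assume GRAM's published form [3]: the scores θ(e_i, e_j) = u_aᵀ tanh(W_a[e_i; e_j] + b_a) are softmax-normalized over A(i), and g_i is the weighted sum of the basic embeddings of c_i and its ancestors. The toy ontology, node names and parameter values are hypothetical.

```python
import math
from typing import Dict, List

Vec = List[float]


def theta(e_i: Vec, e_j: Vec, u_a: Vec, W_a: List[Vec], b_a: Vec) -> float:
    """Compatibility score u_a^T tanh(W_a [e_i; e_j] + b_a), with W_a of shape l x 2d."""
    concat = e_i + e_j  # list concatenation gives [e_i; e_j] in R^{2d}
    hidden = [math.tanh(sum(w * x for w, x in zip(row, concat)) + b) for row, b in zip(W_a, b_a)]
    return sum(u * hv for u, hv in zip(u_a, hidden))


def gram_embedding(code: str, ancestors: Dict[str, List[str]], e: Dict[str, Vec],
                   u_a: Vec, W_a: List[Vec], b_a: Vec) -> Vec:
    """g_i: attention-weighted sum of the basic embeddings over A(i)."""
    nodes = ancestors[code]  # A(i) includes the code itself
    scores = [theta(e[code], e[j], u_a, W_a, b_a) for j in nodes]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [x / total for x in exps]
    return [sum(w * e[j][k] for w, j in zip(weights, nodes)) for k in range(len(e[code]))]


# Hypothetical ontology path: leaf 428.32 -> parent 428 -> chapter 390-459 (d = 2, l = 1).
e = {"428.32": [1.0, 0.0], "428": [0.0, 1.0], "390-459": [1.0, 1.0]}
ancestors = {"428.32": ["428.32", "428", "390-459"]}
g_vec = gram_embedding("428.32", ancestors, e, u_a=[0.0], W_a=[[0.0] * 4], b_a=[0.0])
# With u_a = 0 all scores tie, so g is the plain average of the three embeddings.
```

A rare leaf code whose own embedding is poorly trained can still borrow information from its ancestors through these weights. That is the motivation GRAM gives for the mechanism.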
Base Models
Since the proposed framework is general, all existing diagnosis prediction approaches can be cast into it and treated as base models. Table 1 summarizes the state-of-the-art approaches in terms of the aforementioned techniques. The detailed implementation of these base models is introduced in the “Experimental Setup” section.
The Proposed Framework
Different from the graph-based attention mechanism, which specifies the relationships among diagnosis codes with a given medical ontology, we aim to learn diagnosis code embeddings directly from their medical descriptions. The main components of the proposed diagnosis prediction framework are diagnosis code embedding and predictive model. The diagnosis code embedding component learns medical embeddings from code descriptions, which are used to embed the visit information into a vector representation. The predictive model component predicts future visit information from the embedded visit representations. Because the visit representations are computed from the learned code embeddings, the whole framework can be trained end-to-end. Next, we describe these two components in detail.
Diagnosis Code Embedding
To embed the description of each diagnosis code into a vector representation, convolutional neural networks (CNN) [34] can be employed. CNNs use layers with convolving filters to extract local features and have shown strong performance on natural language processing tasks such as sentence modeling [35] and sentence classification [36].
Figure 1 shows the variant of the CNN architecture used to embed each diagnosis code description \(c_{i}^{\prime}\) into a vector representation e_{i}. We first obtain the pretrained embedding of each word w_{j}, denoted as \(\mathbf{l}_{j} \in \mathbb{R}^{k}\), from fastText [5], where k is the dimensionality of the word vectors. The description \(c_{i}^{\prime}\) with length n (padded where necessary) is represented as
\[ \mathbf{l}_{1:n} = \mathbf{l}_{1} \oplus \mathbf{l}_{2} \oplus \cdots \oplus \mathbf{l}_{n}, \tag{9} \]
where ⊕ is the concatenation operator. Let h denote the size of a word window; then l_{i:i+h−1} represents the concatenation of the h words from l_{i} to l_{i+h−1}. A filter \(\mathbf{W}_{f} \in \mathbb{R}^{h \times k}\) is applied to a window of h words to produce a new feature \(f_{i} \in \mathbb{R}\) with the ReLU activation function:
\[ f_{i} = \mathrm{ReLU}(\mathbf{W}_{f} \cdot \mathbf{l}_{i:i+h-1} + b_{f}), \tag{10} \]
where \(b_{f} \in \mathbb{R}\) is a bias term, · denotes the sum of element-wise products between the filter and the window (viewed as an h×k matrix), and ReLU(f) = max(f, 0). This filter is applied to each possible window of words in the whole description {l_{1:h}, l_{2:h+1}, ⋯, l_{n−h+1:n}} to generate a feature map \(\mathbf{f} \in \mathbb{R}^{n-h+1}\):
\[ \mathbf{f} = [f_{1}, f_{2}, \cdots, f_{n-h+1}]. \tag{11} \]
Next, max pooling [37] is applied over the feature map to obtain the most important feature, i.e., \(\hat{f} = \max(\mathbf{f})\). In this way, one filter produces one feature. To obtain multiple features, we use m filters for each of q different window sizes. All the extracted features are concatenated to form the embedding of each diagnosis code \(\mathbf{e}_{i} \in \mathbb{R}^{d}\) (d = mq). Finally, we obtain the diagnosis code embedding matrix \(\mathbf{E} \in \mathbb{R}^{d \times |\mathcal{C}|}\), where e_{i} is the i-th column of E.
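The description encoder (convolution with ReLU, max pooling, then concatenation across window sizes) can be sketched in plain Python as follows. In the paper the word vectors come from fastText. Here they are hypothetical toy vectors, and the filters are passed in explicitly rather than learned. We also assume zero-padding, so that every description has at least as many words as the largest window.

```python
from typing import List, Sequence, Tuple

Vec = List[float]
Filter = Tuple[List[Vec], float]  # (W_f as h rows of k weights, bias b_f)


def pad(words: Sequence[Vec], min_len: int) -> List[Vec]:
    """Zero-pad a description so that it has at least min_len word vectors."""
    k = len(words[0])
    return list(words) + [[0.0] * k for _ in range(min_len - len(words))]


def conv_max_feature(words: Sequence[Vec], W_f: List[Vec], b_f: float) -> float:
    """Apply one filter to every window l_{i:i+h-1}, then max-pool the feature map."""
    h = len(W_f)
    feature_map = []
    for i in range(len(words) - h + 1):
        s = sum(w * x for row, word in zip(W_f, words[i:i + h]) for w, x in zip(row, word))
        feature_map.append(max(s + b_f, 0.0))  # ReLU
    return max(feature_map)


def embed_description(words: Sequence[Vec], filters_by_window: List[List[Filter]]) -> Vec:
    """Concatenate m max-pooled features for each of the q window sizes (d = m * q)."""
    longest = max(len(W_f) for filters in filters_by_window for W_f, _ in filters)
    words = pad(words, longest)
    return [conv_max_feature(words, W_f, b_f)
            for filters in filters_by_window for W_f, b_f in filters]


# Hypothetical 3-word description (k = 2), q = 2 window sizes (h = 2, 3), m = 1 filter each.
words = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
filters = [
    [([[1.0, 0.0], [0.0, 1.0]], 0.0)],               # h = 2
    [([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], -1.0)],  # h = 3
]
e_i = embed_description(words, filters)
print(e_i)  # [2.0, 3.0]
```

With the settings reported later (m = 100 filters for each of the q = 3 window sizes 2, 3 and 4), this layout gives d = mq = 300.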
The advantage of the proposed CNN-based diagnosis code embedding approach is that diagnosis codes with similar descriptions naturally obtain similar vector representations. Thus, even diagnosis codes without sufficient training EHR data can still obtain reasonable vector representations, which further helps the model improve predictive performance. In the following, we introduce how to use the produced medical embeddings for the diagnosis prediction task.
Predictive Model
Based on the learned diagnosis code embedding matrix E, we can predict patients’ future visit information with a predictive model. Given a visit \(\mathbf {x}_{t} \in \{0, 1\}^{\mathcal {C}}\), we first embed x_{t} into a vector representation \(\mathbf {v}_{t} \in \mathbb {R}^{d}\) with E as follows:
where \(\mathbf{b}_{v} \in \mathbb{R}^{d}\) is the bias vector to be learned. Then v_{t} is fed into the predictive model to predict the (t+1)th visit information, i.e., \(\hat{\mathbf{y}}_{t}\). Next, we cast state-of-the-art diagnosis prediction approaches into the proposed framework as the predictive models.
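Because x_t is multi-hot, the affine part of the visit embedding, E x_t + b_v, reduces to adding the columns of E selected by the codes in the visit to b_v. The sketch below shows only this affine part and leaves out any nonlinearity that Eq. (12) applies on top. The names and values are hypothetical.

```python
from typing import List, Sequence

Vec = List[float]


def embed_visit_affine(x_t: Sequence[int], E_columns: Sequence[Vec], b_v: Vec) -> Vec:
    """Compute E x_t + b_v by summing the columns e_i of E for the codes present in the visit."""
    v = list(b_v)
    for i, present in enumerate(x_t):
        if present:
            for k, value in enumerate(E_columns[i]):
                v[k] += value
    return v


# Hypothetical |C| = 3 codes with d = 2; the visit contains codes 0 and 2.
E_columns = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
v_t = embed_visit_affine([1, 0, 1], E_columns, b_v=[0.5, 0.5])
print(v_t)  # [3.5, 2.5]
```

Skipping the zero entries of x_t avoids a dense |C|-wide matrix product, which matters when a visit contains only a handful of the thousands of possible codes.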
∙Enhanced MLP (MLP+). The simplest predictive model uses only a multilayer perceptron (MLP) with two layers, a fully connected layer and a softmax layer, i.e.,
where h_{t} is obtained from Eq. (1). This model works well when both the number of diagnosis codes and the number of patients’ visits are small. However, MLP+ does not use historical visit information for the prediction. To overcome this limitation of MLP+, we employ recurrent neural networks (RNN) to handle more complicated scenarios.
∙Enhanced RNN (RNN+). For RNN+, the visit embedding vector v_{t} is fed into a GRU, which produces a hidden state \(\mathbf{h}_{t} \in \mathbb{R}^{g}\) as follows:
\[ \mathbf{h}_{t} = \mathrm{GRU}(\mathbf{v}_{t}, \mathbf{h}_{t-1}; \Omega). \tag{14} \]
Then the hidden state h_{t} is fed through the softmax layer to predict the (t+1)th visit information as follows:
where \(\mathbf{W}_{c} \in \mathbb{R}^{|\mathcal{C}| \times g}\). Note that RNN+ uses only the t-th hidden state to make the prediction and does not directly use the hidden states of visits from time 1 to t−1. To take all the information before the prediction into account, attention-based models are proposed in the following.
∙Enhanced Attention-based RNN (RNN_{a}+). According to Eq. (14), we obtain all the hidden states h_{1}, h_{2}, ⋯, h_{t}. Then the location-based attention mechanism is applied to obtain the context vector c_{t} with Eq. (5). Finally, the context vector c_{t} is fed into the softmax layer to make predictions as follows:
∙Enhanced Dipole (Dipole+). One drawback of RNNs is that prediction performance drops when sequences are very long [38]. To overcome this drawback, Dipole [2], which uses bidirectional recurrent neural networks (BRNN) with attention mechanisms, was proposed to improve prediction performance.
Given the visit embeddings from v_{1} to v_{t}, a BRNN learns two sets of hidden states: forward hidden states \(\overrightarrow{\mathbf{h}}_{1}, \cdots, \overrightarrow{\mathbf{h}}_{t}\) and backward hidden states \(\overleftarrow{\mathbf{h}}_{1}, \cdots, \overleftarrow{\mathbf{h}}_{t}\). By concatenating \(\overrightarrow{\mathbf{h}}_{t}\) and \(\overleftarrow{\mathbf{h}}_{t}\), we obtain the final hidden state \(\mathbf{h}_{t} = [\overrightarrow{\mathbf{h}}_{t}; \overleftarrow{\mathbf{h}}_{t}]\) (\(\mathbf{h}_{t} \in \mathbb{R}^{2g}\)). Then the location-based attention mechanism is used to produce the context vector \(\mathbf{c}_{t} \in \mathbb{R}^{2g}\) with Eqs. (4) and (5) (\(\mathbf{W}_{\alpha} \in \mathbb{R}^{2g}\)). With the learned c_{t}, Dipole+ predicts the (t+1)th visit information with a softmax layer, i.e., Eq. (16) with \(\mathbf{W}_{c} \in \mathbb{R}^{|\mathcal{C}| \times 2g}\).
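To make the alignment of forward and backward states concrete, the sketch below does four things. It runs one recurrence left to right and another on the reversed sequence. It then flips the backward states back into time order and concatenates them with the forward states, so that each h_i has size 2g. The `running_sum` step is a hypothetical stand-in for a GRU step, chosen only so that the output is easy to check by hand.

```python
from typing import Callable, List, Sequence

Vec = List[float]
Step = Callable[[Vec, Vec], Vec]  # (v_t, h_{t-1}) -> h_t


def run_rnn(vs: Sequence[Vec], step: Step, h0: Vec) -> List[Vec]:
    """Apply a recurrent step over the sequence and collect every hidden state."""
    states, h = [], h0
    for v in vs:
        h = step(v, h)
        states.append(h)
    return states


def bidirectional_states(vs: Sequence[Vec], fwd: Step, bwd: Step, h0: Vec) -> List[Vec]:
    """Return h_i = [forward h_i ; backward h_i] for i = 1..t."""
    forward = run_rnn(vs, fwd, h0)
    # Run the backward recurrence on the reversed sequence, then flip it back so index i lines up.
    backward = run_rnn(list(reversed(vs)), bwd, h0)[::-1]
    return [f + b for f, b in zip(forward, backward)]


def running_sum(v: Vec, h: Vec) -> Vec:
    """Hypothetical stand-in for a GRU step, used only to make the alignment visible."""
    return [a + b for a, b in zip(v, h)]


vs = [[1.0], [2.0], [3.0]]  # toy visit embeddings v_1..v_3 with d = g = 1
hs = bidirectional_states(vs, running_sum, running_sum, h0=[0.0])
print(hs)  # [[1.0, 6.0], [3.0, 5.0], [6.0, 3.0]]
```

The forward half of h_i summarizes v_1..v_i and the backward half summarizes v_i..v_t. These concatenated states are what the location-based attention then weights.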
∙Enhanced RETAIN (RETAIN+). RETAIN [4] is an interpretable diagnosis prediction model. It uses two GRUs over the reverse time-ordered visit sequence, together with attention mechanisms, to calculate contribution scores for all the diagnosis codes that appear before the prediction.
The visit-level attention scores can be obtained using Eq. (4). For the code-level attention scores, RETAIN employs the following function:
where \(\mathbf {W}_{\beta } \in \mathbb {R}^{d \times g}\) and \(\mathbf {b}_{\beta } \in \mathbb {R}^{d}\) are parameters. Then the context vector \(\mathbf {c}_{t} \in \mathbb {R}^{d}\) is obtained as follows:
With the generated context vector c_{t} and Eq. (16) (\(\mathbf{W}_{c} \in \mathbb{R}^{|\mathcal{C}| \times d}\)), RETAIN+ predicts the patient’s health status at the (t+1)th visit.
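The sketch below shows how the two attention levels combine. It assumes RETAIN's published formulation [4]: the code-level weights are β_i = tanh(W_β h_i + b_β), and the context vector is c_t = Σ_i α_i (β_i ∘ v_i), with α already softmax-normalized. The function names and toy values are hypothetical.

```python
import math
from typing import List, Sequence

Vec = List[float]


def code_level_weights(h_beta: Vec, W_beta: Sequence[Vec], b_beta: Vec) -> Vec:
    """beta_i = tanh(W_beta h_i + b_beta), with W_beta of shape d x g."""
    return [math.tanh(sum(w * x for w, x in zip(row, h_beta)) + b)
            for row, b in zip(W_beta, b_beta)]


def retain_context(alphas: Vec, betas: Sequence[Vec], vs: Sequence[Vec]) -> Vec:
    """c_t = sum_i alpha_i * (beta_i ∘ v_i): visit-level and code-level weights combined."""
    d = len(vs[0])
    return [sum(a * beta[k] * v[k] for a, beta, v in zip(alphas, betas, vs)) for k in range(d)]


vs = [[1.0, 2.0], [3.0, 4.0]]     # hypothetical visit embeddings v_1, v_2 (d = 2)
betas = [[1.0, 1.0], [0.5, 0.0]]  # hypothetical code-level weights beta_1, beta_2
c_t = retain_context([0.5, 0.5], betas, vs)
print(c_t)  # [1.25, 1.0]
```

Because c_t is a weighted sum of the visit embeddings themselves rather than of GRU hidden states, each visit's and each embedding dimension's contribution to the prediction can be read off directly. This is the source of RETAIN's interpretability.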
∙Enhanced GRAM (GRAM+). GRAM [3] is the state-of-the-art approach for learning reasonable and robust representations of diagnosis codes with medical ontologies. To enhance GRAM with the proposed framework, we learn the basic embedding vectors \(\mathbf{e}_{1}, \cdots, \mathbf{e}_{|\mathcal{C}|}\), i.e., E, from the diagnosis code descriptions instead of randomly initializing them. Note that the basic embeddings of the non-leaf nodes are still randomly initialized.
With the diagnosis code embedding matrix G learned as described in the “Preliminaries” section, we obtain the visit-level embedding v_{t} with Eq. (12) (i.e., replacing E with G). Using Eqs. (14) and (15), GRAM+ predicts the (t+1)th visit information.
Remark: A key benefit of the proposed framework is its flexibility and transparency relative to all the existing diagnosis prediction models. Beyond all the aforementioned base approaches, more effective and complicated diagnosis prediction models can also be easily cast into the proposed framework.
Results
In this section, we first introduce the two real-world medical datasets used in the experiments, and then describe the experimental settings. Finally, we validate the proposed framework on the two datasets.
Real-World Datasets
Two medical datasets are used in our experiments to validate the proposed framework: the MIMIC-III dataset [39] and the Heart Failure dataset.
∙ The MIMIC-III dataset, a publicly available EHR dataset, consists of the medical records of 7,499 intensive care unit (ICU) patients over 11 years. For this dataset, we chose the patients who made at least two visits.
∙ The Heart Failure dataset is an insurance claim dataset, which has 4,925 patients and 341,865 visits from the year 2004 to 2015. The patient visits were grouped by week [2], and we chose patients who made at least two visits. Table 2 shows more details about the two datasets.
The diagnosis prediction task aims to predict the diagnosis information of the next visit. In our experiments, following [2, 3], we predict diagnosis categories instead of the actual diagnosis codes. Predicting category information not only increases training speed and predictive performance, but also ensures sufficient granularity for all the diagnoses. The nodes in the second level of the ICD-9 code hierarchy are used as the category labels. For example, the category label of diagnosis code “428.43: Acute on chronic combined systolic and diastolic heart failure” is “Diseases of the circulatory system (390-459)”.
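The following is a minimal sketch of mapping a diagnosis code to its category label. It assumes the category is the ICD-9-CM chapter, as the 428.43 example suggests. The paper does not state how supplementary V and E codes are grouped, so the labels used for them here are placeholders.

```python
from typing import List, Tuple

# ICD-9-CM chapter ranges for numeric diagnosis codes (V and E codes are handled separately).
CHAPTERS: List[Tuple[int, int, str]] = [
    (1, 139, "Infectious and parasitic diseases (001-139)"),
    (140, 239, "Neoplasms (140-239)"),
    (240, 279, "Endocrine, nutritional and metabolic diseases, and immunity disorders (240-279)"),
    (280, 289, "Diseases of the blood and blood-forming organs (280-289)"),
    (290, 319, "Mental disorders (290-319)"),
    (320, 389, "Diseases of the nervous system and sense organs (320-389)"),
    (390, 459, "Diseases of the circulatory system (390-459)"),
    (460, 519, "Diseases of the respiratory system (460-519)"),
    (520, 579, "Diseases of the digestive system (520-579)"),
    (580, 629, "Diseases of the genitourinary system (580-629)"),
    (630, 679, "Complications of pregnancy, childbirth, and the puerperium (630-679)"),
    (680, 709, "Diseases of the skin and subcutaneous tissue (680-709)"),
    (710, 739, "Diseases of the musculoskeletal system and connective tissue (710-739)"),
    (740, 759, "Congenital anomalies (740-759)"),
    (760, 779, "Certain conditions originating in the perinatal period (760-779)"),
    (780, 799, "Symptoms, signs, and ill-defined conditions (780-799)"),
    (800, 999, "Injury and poisoning (800-999)"),
]


def icd9_category(code: str) -> str:
    """Map an ICD-9 diagnosis code such as '428.43' to its chapter label."""
    head = code.split(".")[0].strip().upper()
    if head.startswith("V"):
        return "Supplementary classification (V codes)"  # placeholder grouping
    if head.startswith("E"):
        return "External causes of injury (E codes)"  # placeholder grouping
    number = int(head)
    for low, high, name in CHAPTERS:
        if low <= number <= high:
            return name
    raise ValueError(f"not a valid ICD-9 diagnosis code: {code}")


print(icd9_category("428.43"))  # Diseases of the circulatory system (390-459)
```

Under this mapping, the two heart failure codes from the Background section (428.32 and 398.91) fall into the same category label.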
Experimental Setup
We first introduce the state-of-the-art diagnosis prediction approaches used as base models, then describe the measures used to evaluate the prediction results of all the approaches, and finally present the details of our implementation.
Base Models
In our experiments, we use the following six approaches as base models:
∙MLP. MLP is a naive method, which first embeds the input visit x_{t} into a vector representation v_{t}, and then uses Eq. (1) and Eq. (13) to predict the (t+1)th visit information.
∙RNN. RNN is a commonly used model. The input visit is first embedded into a visitlevel representation v_{t} with a randomly initialized embedding matrix. Then v_{t} is fed into a GRU, and the GRU outputs the hidden state h_{t} (Eq. (14)), which is used to predict the next visit information with Eq. (15).
∙RNN_{a} [2]. RNN_{a} adds the location-based attention mechanism to RNN. After the GRU outputs the hidden states h_{1}, h_{2}, ⋯, h_{t}, RNN_{a} employs Eqs. (4) and (5) to calculate the context vector c_{t}. Finally, RNN_{a} makes predictions using the learned c_{t} and Eq. (16).
∙Dipole [2]. Dipole is the first work to apply bidirectional recurrent neural networks to the diagnosis prediction task. In our experiments, we use the location-based attention mechanism. Compared with RNN_{a}, the difference is that Dipole uses two GRUs to generate the hidden states, and then concatenates these two sets of hidden states to calculate the context vector c_{t} with location-based attention.
∙RETAIN [4]. RETAIN focuses on interpreting the prediction results with a two-level attention model. It feeds the reverse time-ordered visit sequence into two GRUs: one computes the visit-level attention scores with Eq. (4), and the other computes the code-level attention weights with Eq. (17). The context vector c_{t} is obtained using Eq. (18). Based on this context vector, RETAIN predicts the (t+1)th diagnosis codes.
∙GRAM [3]. GRAM is the first work to employ medical ontologies to learn diagnosis code representations and predict future visit information with recurrent neural networks. GRAM first learns the diagnosis code embedding matrix G with the graph-based attention mechanism (Eq. (6)). With the learned G, the input visit x_{t} is embedded into a visit-level representation v_{t}, which is fed into a GRU to produce the hidden state h_{t}. Equation (15) is used to make the final predictions.
For each base model, we design the corresponding enhanced approach for comparison.
Evaluation Measures
To fairly evaluate the performance of all the diagnosis prediction approaches, we assess the results at two levels, visit level and code level, using the measures precision@k and accuracy@k.
∙Visit-level precision@k is defined as the number of correct diagnosis categories among the top k predictions divided by min(k, |y_{t}|), where |y_{t}| is the number of category labels in the (t+1)th visit.
∙Code-level accuracy@k: given a visit V_{t} that contains multiple category labels, each target label scores 1 if it is among the top k guesses and 0 otherwise. Code-level accuracy@k is then defined as the number of correct label predictions divided by the total number of label predictions.
Visit-level precision@k evaluates coarse-grained performance, while code-level accuracy@k evaluates fine-grained performance. For both measures, higher values indicate better performance. In the experiments, we vary k from 5 to 30.
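The two measures can be sketched as follows. The per-visit definitions follow the text. Averaging visit-level precision@k over all visits, and pooling label predictions across visits for code-level accuracy@k, are our assumptions about aggregation. Category labels are represented as integer indices, and the toy scores are hypothetical.

```python
from typing import List, Sequence, Set


def top_k(scores: Sequence[float], k: int) -> Set[int]:
    """Indices of the k highest-scoring categories."""
    return set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])


def visit_precision_at_k(all_scores: Sequence[Sequence[float]],
                         all_labels: Sequence[Set[int]], k: int) -> float:
    """Mean over visits of |top-k ∩ y_t| / min(k, |y_t|)."""
    values = [len(top_k(s, k) & y) / min(k, len(y)) for s, y in zip(all_scores, all_labels)]
    return sum(values) / len(values)


def code_accuracy_at_k(all_scores: Sequence[Sequence[float]],
                       all_labels: Sequence[Set[int]], k: int) -> float:
    """Correct label predictions divided by all label predictions, pooled over visits."""
    correct = sum(len(top_k(s, k) & y) for s, y in zip(all_scores, all_labels))
    total = sum(len(y) for y in all_labels)
    return correct / total


# Two hypothetical visits over 4 categories.
scores: List[List[float]] = [[0.9, 0.1, 0.5, 0.3], [0.2, 0.8, 0.1, 0.7]]
labels: List[Set[int]] = [{0, 3}, {2}]
print(visit_precision_at_k(scores, labels, k=2))  # 0.25
print(code_accuracy_at_k(scores, labels, k=2))
```

The min(k, |y_t|) denominator keeps visit-level precision from being penalized when a visit has fewer true categories than k. Code-level accuracy instead weights every label equally, so visits with many labels count more.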
Implementation Details
We extract the diagnosis code descriptions from ICD9Data.com. All the approaches are implemented with Theano 0.9.0 [40]. We randomly divide each dataset into training, validation and test sets in a 0.75:0.10:0.15 ratio. The validation set is used to select the best parameter values within the 100 training iterations. For training, we use Adadelta [41] with a mini-batch of 100 patients. ℓ2-norm regularization with a coefficient of 0.001 is applied to all the approaches.
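A sketch of the data split and mini-batching described above follows. Splitting at the patient level, the fixed seed, and the rounding of split sizes are assumptions for illustration. The helper names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def split_patients(patient_ids: Sequence[int], seed: int = 0,
                   ratios: Tuple[float, float, float] = (0.75, 0.10, 0.15)
                   ) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle patient ids and cut them into training, validation and test sets."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)  # a local RNG keeps the split reproducible
    n_train = round(ratios[0] * len(ids))
    n_valid = round(ratios[1] * len(ids))
    return ids[:n_train], ids[n_train:n_train + n_valid], ids[n_train + n_valid:]


def mini_batches(ids: Sequence[int], batch_size: int = 100) -> List[List[int]]:
    """Group patient ids into consecutive mini-batches of at most batch_size patients."""
    return [list(ids[i:i + batch_size]) for i in range(0, len(ids), batch_size)]


train_ids, valid_ids, test_ids = split_patients(range(1000))
print(len(train_ids), len(valid_ids), len(test_ids))  # 750 100 150
print(len(mini_batches(train_ids)))  # 8
```

Splitting by patient rather than by visit keeps all visits of one patient in the same set, so no patient history leaks from training into evaluation.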
For a fair comparison, we use the same g=128 (i.e., the dimensionality of hidden states) for all the base models and enhanced approaches except MLP and MLP+. For the proposed approaches on both datasets, the word embedding size is 300, the word window sizes h are set to 2, 3 and 4 (thus q=3), and we use m=100 filters for each window size, so the code embeddings have d=mq=300 dimensions. For all the base models, we set d=180 on the MIMIC-III dataset and d=150 on the Heart Failure dataset. For GRAM, l is 100.
Results of Diagnosis Prediction
Table 3 shows the visit-level precision of all the base models and their corresponding enhanced approaches, and Table 4 lists the code-level accuracy for different values of k. Both tables show that the enhanced diagnosis prediction approaches improve the prediction performance on both the MIMIC-III and Heart Failure datasets.
Performance Analysis for the MIMIC-III Dataset
On the MIMIC-III dataset, the overall performance of every enhanced diagnosis prediction approach is better than that of all the base models. Among the proposed approaches, RETAIN+ and MLP+ achieve the highest accuracy. MLP+ does not use recurrent neural networks and directly predicts the future diagnosis information from the learned visit embedding v_{t}. RETAIN+ makes its final predictions using both the context vector, which is learned from visit-level and code-level attention scores, and the learned visit embeddings. In contrast, all the remaining proposed approaches use the hidden states output by GRUs to predict the next visit. From this analysis, we conclude that feeding visit embeddings directly into the final prediction improves the predictive performance on the MIMIC-III dataset. This is reasonable because the visit sequences in MIMIC-III are short on average. Such short sequences may not allow RNN-based models to learn accurate hidden states, so those methods cannot achieve the highest accuracy.
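The difference between MLP+ and the GRU-based variants can be illustrated with a minimal forward pass that predicts the next visit directly from v_{t}, with no recurrence. The specific form used here, v_{t} = ReLU(E^T x_{t}) for a multi-hot visit vector x_{t} and a description-derived code-embedding matrix E, followed by a softmax output layer, is an assumption for illustration; the actual enhanced models may use different layers. All names and sizes are hypothetical.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()            # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def mlp_plus_forward(x_t: np.ndarray, E: np.ndarray,
                     W_out: np.ndarray, b_out: np.ndarray) -> np.ndarray:
    """x_t: multi-hot vector over |C| diagnosis codes in visit t.
    E: (|C|, emb_dim) code-embedding matrix, e.g. one CNN output row per code.
    W_out, b_out: output layer mapping emb_dim -> number of category labels.
    Returns predicted probabilities of category labels for visit t + 1."""
    v_t = np.maximum(E.T @ x_t, 0.0)         # visit embedding from its codes
    return softmax(W_out @ v_t + b_out)      # no RNN: predict directly from v_t


rng = np.random.default_rng(1)
n_codes, emb_dim, n_labels = 50, 300, 20     # hypothetical sizes
E = rng.normal(scale=0.1, size=(n_codes, emb_dim))
W_out = rng.normal(scale=0.1, size=(n_labels, emb_dim))
b_out = np.zeros(n_labels)
x_t = np.zeros(n_codes)
x_t[[3, 17, 42]] = 1.0                       # a visit containing three codes
probs = mlp_plus_forward(x_t, E, W_out, b_out)
top5 = np.argsort(-probs)[:5]                # ranked labels for precision@5
```

Because v_{t} depends only on the current visit, this path needs no history at all, which is consistent with the observation that it helps most when visit sequences are short.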
This observation is also reflected in the performance of the base models. Compared with the naive base model MLP, all four RNN-based approaches (RNN, RNN_{a}, Dipole and RETAIN) achieve lower precision and accuracy. This again confirms that RNN-based models do not work well with short sequences. Among the RNN-based approaches, the location-based attention models RNN_{a} and Dipole perform worse than RNN and RETAIN, which shows that learning attention mechanisms requires abundant EHR data. RETAIN still achieves higher precision and accuracy than RNN, which demonstrates that directly using visit embeddings in the final prediction may achieve better performance on datasets with shorter visit sequences. GRAM achieves performance comparable to that of the naive base model MLP, which shows that employing external information can compensate for the lack of training EHR data in the diagnosis prediction task.
An interesting observation is that, as expected, performance improves as k increases, except for visit-level precision on the MIMIC-III dataset, which is due to insufficient training data. Labels with insufficient training data obtain lower predicted probabilities than labels with abundant data. Thus, for visits containing such labels, the number of correct predictions at k=10 or k=15 may be the same as at k=5. However, these counts are divided by a larger min(k, y_{t}), so the average precision is lower than at k=5.
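A small worked example makes the effect of the min(k, y_{t}) denominator concrete. The visit below is hypothetical: it has y_{t} = 12 true category labels, the model ranks 4 of them in its top 5, and the remaining (rare) labels never reach the top 15, so the hit count stays at 4 as k grows. The helper name `precision_from_hits` is illustrative.

```python
def precision_from_hits(hits: int, k: int, y_t: int) -> float:
    """Visit-level precision@k given the number of correct labels in the top k."""
    return hits / min(k, y_t)


# Hypothetical visit: 12 true labels, 4 correct labels in the top 5,
# and no additional correct labels up to k = 15.
for k in (5, 10, 15):
    print(k, precision_from_hits(hits=4, k=k, y_t=12))
# 5 -> 4 / min(5, 12)  = 4 / 5  = 0.8
# 10 -> 4 / min(10, 12) = 4 / 10 = 0.4
# 15 -> 4 / min(15, 12) = 4 / 12, about 0.333
```

Code-level accuracy@k has no such k-dependent denominator: it divides by the fixed number of true labels, and the top-k set only grows with k, so that measure cannot decrease as k increases.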
Performance Analysis for the Heart Failure Dataset
On the Heart Failure dataset, the enhanced approaches still outperform the corresponding base models, especially GRAM+, which achieves much higher accuracy than the other approaches. The reason is that GRAM+ not only uses medical ontologies to learn robust diagnosis code embeddings, but also employs code descriptions to further improve the performance, as the comparison between GRAM and GRAM+ confirms.
Among all the approaches, RETAIN has the lowest precision and accuracy, which shows that directly using visit-level embeddings in the final prediction may not work on the Heart Failure dataset; the same can be observed from the performance of MLP. However, when code descriptions are taken into consideration, the performance increases substantially: at k=5, the visit-level precision and code-level accuracy of RETAIN improve by 37% and 42%, respectively. MLP performs better than RETAIN, but still worse than the other RNN variants. This illustrates that simple multilayer perceptrons do not work well on complicated EHR datasets. Although learning medical embeddings of diagnosis codes improves its predictive performance, MLP+ is still less accurate than most approaches. This directly validates the use of recurrent neural networks for the diagnosis prediction task.
Both location-based attention approaches, RNN_{a} and Dipole, perform better than RNN, which demonstrates that attention mechanisms help the models enhance their predictive ability. The comparison between RNN_{a} and Dipole confirms that when visit sequences are long, bidirectional recurrent neural networks can retain more useful information and perform better than unidirectional ones.
Based on the above analysis, we conclude that learning diagnosis code embeddings from descriptions helps all the state-of-the-art diagnosis prediction approaches significantly improve their performance on different real-world datasets.
Discussion
The main contribution of this work is to incorporate code descriptions to improve the prediction performance of state-of-the-art models. The experimental results on two real datasets confirm the effectiveness of the proposed framework. Next, we further discuss how the performance changes with the degree of data sufficiency and examine the representations learned by the proposed framework.
Data Sufficiency
In healthcare, it is hard to collect enough EHR data for rare diseases. To validate the sensitivity of all the diagnosis prediction approaches to data sufficiency, we conduct the following experiments on the MIMIC-III dataset. We first calculate the frequency of each category label in the training data, then rank the labels by frequency, and finally divide them into four groups: 0-25, 25-50, 50-75 and 75-100. The category labels in group 0-25 are the rarest in the training data, while the labels in group 75-100 are the most common. We then compute the average accuracy of the labels in each group. The code-level accuracy@20 on the MIMIC-III dataset is shown in Fig. 2. The X-axis denotes all the base models and their corresponding enhanced approaches, and the Y-axis represents the average accuracy of the approaches.
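The grouping procedure can be sketched as follows. It ranks the category labels seen in training from rarest to most common, cuts the ranking into four equal-sized groups, and averages the per-label code-level accuracy (for example at k = 20) within each group. The input format (a flat list of training label occurrences plus per-label correct/total counts on the test set), the tie-breaking rule (which assumes labels are mutually comparable, e.g. all strings), and the function name are assumptions.

```python
from collections import Counter
from typing import Dict, Hashable, Iterable, List


def accuracy_by_frequency_group(train_labels: Iterable[Hashable],
                                label_correct: Dict[Hashable, int],
                                label_total: Dict[Hashable, int],
                                n_groups: int = 4) -> List[float]:
    """Average per-label accuracy in each frequency group, rarest group first
    (i.e., 0-25, 25-50, 50-75, 75-100 when n_groups = 4)."""
    freq = Counter(train_labels)
    # Rank labels by training frequency (rarest first); ties broken by label.
    ranked = sorted(freq, key=lambda label: (freq[label], label))
    n = len(ranked)
    averages = []
    for g in range(n_groups):
        group = ranked[g * n // n_groups:(g + 1) * n // n_groups]
        # Only labels that actually occur in the test set have an accuracy.
        accs = [label_correct.get(c, 0) / label_total[c]
                for c in group if label_total.get(c, 0) > 0]
        averages.append(sum(accs) / len(accs) if accs else 0.0)
    return averages


# Hypothetical data: labels "a".."h" occur 1..8 times in training,
# and each occurs 10 times in the test set.
labels = "abcdefgh"
train = [c for i, c in enumerate(labels) for _ in range(i + 1)]
total = {c: 10 for c in labels}
correct = dict(zip(labels, [0, 0, 2, 4, 6, 6, 9, 9]))
print(accuracy_by_frequency_group(train, correct, total))  # approximately [0.0, 0.3, 0.6, 0.9]
```

Averaging per label (rather than pooling all predictions in a group) keeps a handful of very frequent labels from dominating each group's score.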
From Fig. 2, we observe that the accuracy of all the enhanced diagnosis prediction approaches is higher than that of all the base models in the first three groups. Even though MLP and RETAIN achieve higher accuracy than RNN, RNN_{a} and Dipole, as shown in Table 4, both approaches have an accuracy of 0 in group 0-25. However, when the proposed framework is applied to MLP and RETAIN, both make some correct predictions for rare diseases. The same observation holds in groups 25-50 and 50-75. This validates that considering the medical meanings of diagnosis codes helps existing models enhance their predictive ability even without sufficient training EHR data.
In Fig. 2d, all the labels have abundant training EHR data, and thus all the approaches achieve comparable performance. This result again confirms that the enhanced approaches improve the predictive performance on rare diseases, i.e., labels without sufficient training EHR records. Among all the base models, GRAM obtains the highest accuracy in groups 0-25, 25-50 and 50-75, which illustrates the effectiveness of incorporating external medical knowledge. Furthermore, even though GRAM already learns medical embeddings with ontologies, incorporating code descriptions still improves its predictive accuracy, as can be observed from both Fig. 2 and Table 4.
Interpretable Representation
For the diagnosis prediction task, the interpretability of the learned medical code embeddings is very important. Thus, we conduct the following experiment to qualitatively examine the representations learned by all the approaches on the MIMIC-III dataset. We randomly select 2000 diagnosis codes and plot them in a 2D space with t-SNE [42], as shown in Fig. 3. The color of each dot represents its first-level disease category in the CCS multi-level hierarchy, following [3]. We observe that, except for GRAM, the baselines cannot learn interpretable representations. However, after considering the semantic meanings learned from diagnosis code descriptions, all the proposed approaches learn some interpretable cluster structures in the representations. In particular, GRAM+ not only maintains the advantages of GRAM but also improves the prediction accuracy. From Fig. 3, we conclude that the proposed semantic diagnosis prediction framework is effective and interpretable even when the training EHR data are insufficient.
Conclusions
Diagnosis prediction from EHR data is a challenging yet practical research task in the healthcare domain. Most state-of-the-art diagnosis prediction models employ recurrent neural networks to model patients' sequential visit records, and exploit attention mechanisms to improve the predictive performance and provide interpretability for the prediction results. However, all the existing models ignore the medical descriptions of diagnosis codes, which are highly important for the diagnosis prediction task, especially when the EHR data are insufficient.
In this paper, we propose a novel and effective diagnosis prediction framework that takes the medical meanings of diagnosis codes into account when predicting patients' future visit information. The proposed framework includes two basic components: diagnosis code embedding and predictive model. In the diagnosis code embedding component, medical representations of diagnosis codes are learned from their descriptions with a convolutional neural network on top of pretrained word embeddings. Based on the learned embeddings, the input visit information is embedded into a visit-level vector representation, which is then fed into the predictive model component. In the predictive model component, the state-of-the-art diagnosis prediction models are redesigned to consider diagnosis code meanings, which significantly improves their predictive performance. Experimental results on two real-world medical datasets demonstrate the effectiveness and robustness of the proposed framework for the diagnosis prediction task. A further experiment shows that the enhanced diagnosis prediction approaches outperform the corresponding state-of-the-art approaches when EHR data are insufficient. Finally, the learned medical code representations are visualized to demonstrate the interpretability of the proposed framework.
Availability of data and materials
The MIMIC-III dataset can be obtained from: https://mimic.physionet.org/gettingstarted/access/.
Abbreviations
BRNN: Bidirectional recurrent neural network
CCS: Clinical classifications software
CNN: Convolutional neural networks
DAG: Directed acyclic graph
Dipole: Attention-based bidirectional recurrent neural networks
Dipole+: Enhanced attention-based bidirectional recurrent neural networks
EHR: Electronic health records
GRAM: Graph-based attention model
GRAM+: Enhanced graph-based attention model
GRU: Gated recurrent unit
LSTM: Long short-term memory
MIMIC-III: Medical information mart for intensive care
MLP: Multilayer perceptron
MLP+: Enhanced multilayer perceptron
RETAIN: Reverse time attention mechanism
RETAIN+: Enhanced reverse time attention mechanism
RNN: Recurrent neural networks
RNN+: Enhanced recurrent neural network
RNN_{a}: Attention-based recurrent neural network
RNN_{a}+: Enhanced attention-based recurrent neural network
SDA: Stacked denoising autoencoders
T-LSTM: Time-aware long short-term memory
References
[1] Miotto R, Wang F, Wang S, Jiang X, Dudley JT. Deep learning for healthcare: Review, opportunities and challenges. Brief Bioinform. 2017; 19(6):1236–46.
[2] Ma F, Chitta R, Zhou J, You Q, Sun T, Gao J. Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In: KDD. New York: ACM: 2017. p. 1903–11.
[3] Choi E, Bahadori MT, Song L, Stewart WF, Sun J. GRAM: Graph-based attention model for healthcare representation learning. In: KDD. New York: ACM: 2017. p. 787–95.
[4] Choi E, Bahadori MT, Sun J, Kulas J, Schuetz A, Stewart W. RETAIN: An interpretable predictive model for healthcare using reverse time attention mechanism. In: NIPS. Curran Associates, Inc.: 2016. p. 3504–12.
[5] Bojanowski P, Grave E, Joulin A, Mikolov T. Enriching word vectors with subword information. arXiv preprint arXiv:1607.04606. 2016.
[6] Dua S, Acharya UR, Dua P. Machine Learning in Healthcare Informatics vol. 56; 2014.
[7] Suo Q, Zhong W, Ma F, Yuan Y, Huai M, Zhang A. Multi-task sparse metric learning for monitoring patient similarity progression. In: ICDM. IEEE: 2018. p. 477–86.
[8] Ma F, Meng C, Xiao H, Li Q, Gao J, Su L, Zhang A. Unsupervised discovery of drug side-effects from heterogeneous data sources. In: KDD. New York: ACM: 2017. p. 967–76.
[9] Yuan Y, Xun G, Ma F, Wang Y, Du N, Jia K, Su L, Zhang A. MuVAN: A multi-view attention network for multivariate temporal data. In: ICDM. IEEE: 2018. p. 717–26.
[10] Ma F, Wang Y, Xiao H, Yuan Y, Chitta R, Zhou J, Gao J. A general framework for diagnosis prediction via incorporating medical code descriptions. In: BIBM. IEEE: 2018. p. 1070–5.
[11] Zhang S, Xie P, Wang D, Xing EP. Medical diagnosis from laboratory tests by combining generative and discriminative learning. arXiv preprint arXiv:1711.04329. 2017.
[12] Zhang Y, Chen R, Tang J, Stewart WF, Sun J. LEAP: Learning to prescribe effective and safe treatment combinations for multimorbidity. In: KDD. New York: ACM: 2017. p. 1315–24.
[13] Yuan Y, Xun G, Ma F, Suo Q, Xue H, Jia K, Zhang A. A novel channel-aware attention framework for multi-channel EEG seizure detection via multi-view deep learning. In: BHI. IEEE: 2018. p. 206–9.
[14] Che Z, Kale D, Li W, Bahadori MT, Liu Y. Deep computational phenotyping. In: KDD. New York: ACM: 2015. p. 507–16.
[15] Nguyen P, Tran T, Wickramasinghe N, Venkatesh S. Deepr: A convolutional net for medical records. IEEE J Biomed Health Inform. 2017 Jan; 21(1):22–30.
[16] Yuan Y, Jia K, Ma F, Xun G, Wang Y, Su L, Zhang A. Multivariate sleep stage classification using hybrid self-attentive deep learning networks. In: BIBM. IEEE: 2018. p. 963–8.
[17] Suo Q, Ma F, Yuan Y, Huai M, Zhong W, Zhang A. Personalized disease prediction using a CNN-based similarity learning method. In: BIBM. IEEE: 2017. p. 811–6.
[18] Suo Q, Ma F, Yuan Y, Huai M, Zhong W, Gao J, Zhang A. Deep patient similarity learning for personalized healthcare. IEEE Trans NanoBioscience. 2018; 17(3).
[19] Cheng Y, Wang F, Zhang P, Hu J. Risk prediction with electronic health records: A deep learning approach. In: SDM. SIAM: 2016. p. 432–40.
[20] Che Z, Cheng Y, Zhai S, Sun Z, Liu Y. Boosting deep learning risk prediction with generative adversarial networks for electronic health records. In: ICDM. IEEE: 2017. p. 787–92.
[21] Ma F, Gao J, Suo Q, You Q, Zhou J, Zhang A. Risk prediction on electronic health records with prior medical knowledge. In: KDD. New York: ACM: 2018. p. 1910–19.
[22] Pham T, Tran T, Phung D, Venkatesh S. DeepCare: A deep dynamic memory model for predictive medicine. In: PAKDD. Springer: 2016. p. 30–41.
[23] Che C, Xiao C, Liang J, Jin B, Zhou J, Wang F. An RNN architecture with dynamic temporal matching for personalized predictions of Parkinson's disease. In: SDM. SIAM: 2017. p. 198–206.
[24] Che Z, Purushotham S, Cho K, Sontag D, Liu Y. Recurrent neural networks for multivariate time series with missing values. arXiv preprint arXiv:1606.01865. 2016.
[25] Lipton ZC, Kale DC, Wetzel R. Modeling missing data in clinical time series with RNNs. In: MLH: 2016. p. 253–70.
[26] Lipton ZC, Kale DC, Elkan C, Wetzel R. Learning to diagnose with LSTM recurrent neural networks. In: ICLR: 2015.
[27] Choi E, Bahadori MT, Schuetz A, Stewart WF, Sun J. Doctor AI: Predicting clinical events via recurrent neural networks. In: MLH: 2016. p. 301–18.
[28] Choi E, Bahadori MT, Searles E, Coffey C, Thompson M, Bost J, Tejedor-Sojo J, Sun J. Multi-layer representation learning for medical concepts. In: KDD. New York: ACM: 2016. p. 1495–504.
[29] Ma F, You Q, Xiao H, Chitta R, Zhou J, Gao J. KAME: Knowledge-based attention model for diagnosis prediction in healthcare. In: CIKM. New York: ACM: 2018. p. 743–52.
[30] Suo Q, Ma F, Canino G, Gao J, Zhang A, Veltri P, Gnasso A. A multi-task framework for monitoring health conditions via attention-based recurrent neural networks. In: AMIA. American Medical Informatics Association: 2017. p. 1665–74.
[31] Hochreiter S, Schmidhuber J. Long short-term memory. Neural Comput. 1997; 9(8):1735–80.
[32] Baytas IM, Xiao C, Zhang X, Wang F, Jain AK, Zhou J. Patient subtyping via time-aware LSTM networks. In: KDD. New York: ACM: 2017. p. 65–74.
[33] Cho K, Van Merriënboer B, Bahdanau D, Bengio Y. On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259. 2014.
[34] LeCun Y, Bottou L, Bengio Y, Haffner P. Gradient-based learning applied to document recognition. Proc IEEE. 1998; 86(11):2278–324.
[35] Blunsom P, Grefenstette E, Kalchbrenner N. A convolutional neural network for modelling sentences. In: ACL. Association for Computational Linguistics: 2014. p. 655–65.
[36] Kim Y. Convolutional neural networks for sentence classification. In: EMNLP. Association for Computational Linguistics: 2014. p. 1746–51.
[37] Collobert R, Weston J, Bottou L, Karlen M, Kavukcuoglu K, Kuksa P. Natural language processing (almost) from scratch. J Mach Learn Res. 2011; 12(Aug):2493–537.
[38] Schuster M, Paliwal KK. Bidirectional recurrent neural networks. IEEE Trans Sig Process. 1997; 45(11):2673–81.
[39] Johnson AE, Pollard TJ, Shen L, Lehman LH, Feng M, Ghassemi M, Moody B, Szolovits P, Celi LA, Mark RG. MIMIC-III, a freely accessible critical care database. Sci Data. 2016; 3:160035.
[40] Theano Development Team. Theano: A Python framework for fast computation of mathematical expressions. arXiv preprint arXiv:1605.02688. 2016.
[41] Zeiler MD. ADADELTA: An adaptive learning rate method. arXiv preprint arXiv:1212.5701. 2012.
[42] van der Maaten L, Hinton G. Visualizing data using t-SNE. J Mach Learn Res. 2008; 9(Nov):2579–605.
Acknowledgements
This work is supported in part by the US National Science Foundation under grant IIS-1747614. The authors would like to thank NVIDIA Corporation for the donation of the Titan Xp GPU. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. A short version of this work titled "A General Framework for Diagnosis Prediction via Incorporating Medical Code Descriptions" was presented at the International Conference on Bioinformatics and Biomedicine in Madrid on 3-6 December 2018.
About this supplement
This article has been published as part of BMC Medical Informatics and Decision Making Volume 19 Supplement 6, 2019: Selected articles from the IEEE BIBM International Conference on Bioinformatics & Biomedicine (BIBM) 2018: medical informatics and decision making. The full contents of the supplement are available online at https://bmcmedinformdecismak.biomedcentral.com/articles/supplements/volume-19-supplement-6.
Funding
Publication costs were funded in part by the US National Science Foundation under grant IIS-1747614.
Author information
Contributions
FM, RC, JZ and JG developed the study concept and designed the model. RC acquired the EHR data, and FM processed the EHR data. FM, YW, HX and YY acquired and processed the medical code descriptions. FM, YW and YY programmed the CNN algorithm. FM carried out the experiments. FM, HX and JG analyzed the data and the experimental results. FM, RC, JZ and JG drafted the manuscript. All authors were involved in the revision of the manuscript. All authors read and approved the final manuscript.
Corresponding author
Correspondence to Fenglong Ma.
Ethics declarations
Ethics approval and consent to participate
Not applicable.
Consent for publication
Not applicable.
Competing interests
The authors declare that they have no competing interests.
Additional information
Publisher’s Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Rights and permissions
Open Access This article is distributed under the terms of the Creative Commons Attribution 4.0 International License (http://creativecommons.org/licenses/by/4.0/), which permits unrestricted use, distribution, and reproduction in any medium, provided you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons license, and indicate if changes were made. The Creative Commons Public Domain Dedication waiver (http://creativecommons.org/publicdomain/zero/1.0/) applies to the data made available in this article, unless otherwise stated.
About this article
Cite this article
Ma, F., Wang, Y., Xiao, H. et al. Incorporating medical code descriptions for diagnosis prediction in healthcare. BMC Med Inform Decis Mak 19, 267 (2019). https://doi.org/10.1186/s12911-019-0961-2
Keywords
 Healthcare informatics
 Diagnosis prediction
 Medical code embeddings