In this section, we give a detailed description of the proposed PAVE, which consists of four main parts. First, an embedding module represents medical events, variable values and their happening times as vectors. Then, a self-attention module captures the pattern information between events. Next, a pattern attention module fuses all the pattern features, which are fed into a fully connected layer to predict the clinical outcomes. The framework of PAVE is shown in Fig. 1.
Problem definition and notation
The risk prediction task can be regarded as a binary classification problem: given a sequence of medical events, the framework aims to predict whether the patient will experience a certain medical event (e.g., a specific diagnosis code, mortality) in the future.
A patient’s EHR data consist of two main parts: static information and dynamic information. Static information comprises his/her demographics, such as gender and age. We represent each patient’s demographics as one-hot vectors, with ages divided into several age groups (e.g., 20–29, 30–39).
Dynamic information comprises his/her historical records, including diagnosis codes, medications, lab tests and vital signs (vital signs are available for ICU patients). Diagnosis codes are Boolean-valued, while the other variables are real-valued. A single visit may contain several diagnosis codes and many collections of lab tests and vital signs, and each collection usually has missing values for some lab-test and vital-sign items.
Given a patient, his/her data are denoted as (x, \({\hat{y}}\)). The input x consists of the demographics d and a sequence of n EHR records, denoted as \((e_1, t_1), (e_2, t_2),\ldots , (e_n, t_n)\), where \(t_i\) is the happening time of event \(e_i\). \({\hat{y}}\) is the risk ground truth.
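To make the notation concrete, below is a minimal Python sketch of how one patient’s record (x, \({\hat{y}}\)) could be organized; the class and field names and the toy values are illustrative assumptions, not part of the original data description.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Event:
    code: str                # event identifier, e.g. a diagnosis code or lab-test name
    value: Optional[float]   # None for Boolean events (diagnosis codes), a float otherwise
    time: float              # happening time t_i of event e_i

@dataclass
class Patient:
    demographics: List[int]  # one-hot group indices, e.g. [gender_group, age_group]
    events: List[Event]      # the sequence (e_1, t_1), ..., (e_n, t_n)
    label: int               # risk ground truth \hat{y}, 0 or 1

# a toy patient: two events and a positive label (values are invented)
p = Patient(demographics=[0, 2],
            events=[Event("ICD9_428", None, 0.0),
                    Event("lab_creatinine", 1.8, 6.0)],
            label=1)
```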
Embedding module
In this subsection, we present a new event embedding that takes into account variable values, the corresponding happening times and patient demographics. As shown in Fig. 2, the embedding module takes the events (together with their values), happening times and demographics as input and adopts three embedding layers to project them into vectors.
Time embedding
The first embedding layer is the time embedding layer, which maps the happening time \(t_i\) into a vector \(v_i^t \in R^k\). Here \(t_i\) is the time interval between event \(e_i\) and the preceding event. The \(j{\mathrm{th}}\) dimension of \(v_i^t\) is computed as:
$$\begin{aligned} v_{i, j}^t = {\left\{ \begin{array}{ll} \sin \left( \frac{t_i * j}{t_m * k}\right) , &\quad {\text {if j is even}} \\ \cos \left( \frac{t_i * j}{t_m * k}\right) , &\quad {\text {if j is odd}}, \end{array}\right. } \end{aligned}$$
(1)
where \(t_m\) is the maximum time interval and k denotes the dimension of \(v_{i}^t\).
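To illustrate Eq. (1), here is a minimal PyTorch sketch; the function name and the assumption that the intervals arrive as a 1-D tensor are ours.

```python
import torch

def time_embedding(t: torch.Tensor, t_max: float, k: int) -> torch.Tensor:
    """Sinusoidal interval-time embedding of Eq. (1).

    t     : (n,) tensor of interval times t_i
    t_max : maximum interval time t_m observed in the data
    k     : embedding dimension
    """
    j = torch.arange(k, dtype=torch.float32)       # dimension indices 0, ..., k-1
    angles = t.unsqueeze(-1) * j / (t_max * k)     # (n, k): t_i * j / (t_m * k)
    # even dimensions use sin, odd dimensions use cos, as in Eq. (1)
    return torch.where(j % 2 == 0, torch.sin(angles), torch.cos(angles))
```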
Value embedding
The second embedding layer is the medical event value embedding layer, which maps each event \(e_i\) and its value \(v_i\) into a vector \(v_i^e \in R^k\). Given an event and its value, we first map the event into a vector \(v_i^{e,e}\) via a fully connected layer. If the event is Boolean-valued (e.g., a diagnosis code), we directly use \(v_i^{e,e}\) as \(v_i^e\). Otherwise, for real-valued events (e.g., lab tests), the value embedding layer generates a vector \(v_i^{e,v}\) from the value \(v_i\) in the same way as the time embedding layer. The \(j{\mathrm{th}}\) dimension of \(v_i^{e,v}\) is computed as:
$$\begin{aligned} v_{i, j}^{e,v} = {\left\{ \begin{array}{ll} \sin \left( \frac{(v_i - v_{min}) * j}{(v_{max} - v_{min}) * k}\right) , &\quad {\text {if j is even}} \\ \cos \left( \frac{(v_i - v_{min}) * j}{(v_{max} - v_{min}) * k}\right) , &\quad {\text {if j is odd,}} \end{array}\right. } \end{aligned}$$
(2)
where \(v_{min}\) and \(v_{max}\) are the minimum and maximum values of the corresponding variable, and k denotes the dimension of \(v_i^{e,v}\). Given \(v_i^{e,v}\) and \(v_i^{e,e}\), a linear function combines them into \(v_i^e\):
$$\begin{aligned} v_i^e = v_i^{e,v} W_v + v_i^{e,e} W_e + b_e, \end{aligned}$$
(3)
where \(W_v, W_e \in R^{k \times k}\) and \(b_e \in R^k\) are learnable parameters.
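Eqs. (2) and (3) can be sketched in PyTorch as follows; the module name, passing the per-variable \(v_{min}\)/\(v_{max}\) as input tensors, and the convention of passing 0 as the value of Boolean events are our assumptions.

```python
import torch
import torch.nn as nn

class ValueEmbedding(nn.Module):
    """Sketch of Eqs. (2)-(3): an event embedding combined with a
    sinusoidal embedding of the event's min-max-normalized value."""

    def __init__(self, num_events: int, k: int):
        super().__init__()
        self.event_emb = nn.Embedding(num_events, k)   # produces v_i^{e,e}
        self.W_v = nn.Linear(k, k, bias=False)
        self.W_e = nn.Linear(k, k, bias=True)          # its bias plays the role of b_e

    def forward(self, event_ids, values, v_min, v_max, is_boolean):
        # event_ids: (n,) long; values, v_min, v_max: (n,) float; is_boolean: (n,) bool
        # for Boolean events, pass values = 0 (they are ignored below)
        v_ee = self.event_emb(event_ids)                         # (n, k)
        k = v_ee.size(-1)
        j = torch.arange(k, dtype=torch.float32)
        norm = (values - v_min) / (v_max - v_min)                # min-max normalization
        angles = norm.unsqueeze(-1) * j / k                      # (n, k), as in Eq. (2)
        v_ev = torch.where(j % 2 == 0, torch.sin(angles), torch.cos(angles))
        v_e = self.W_v(v_ev) + self.W_e(v_ee)                    # Eq. (3)
        # Boolean events (e.g. diagnosis codes) keep the plain event embedding
        return torch.where(is_boolean.unsqueeze(-1), v_ee, v_e)
```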
Demographic embedding
The third embedding layer is the demographic embedding layer, which embeds d into a matrix \(v^d \in R^{|d| \times k}\). A demographic attention mechanism is then leveraged to attend to the relevant demographic information.
$$\begin{aligned} v_i^{d,a}& = \sum _{j=1}^{|d|} v_j^d * \alpha _j^{d,i} \nonumber \\ \alpha _j^{d,i}& = \frac{\exp (\beta _j^{d,i}) }{ \sum _{u=1}^{|d|} \exp (\beta _u^{d,i})}\nonumber \\ \beta _j^{d,i}& = v_j^d W_{v,d} + v_i^e W_{v,e}, \end{aligned}$$
(4)
where \(W_{v,d}, W_{v, e} \in R^k\) are learnable parameters, \(v_j^d \in R^k\) denotes the \(j{\mathrm{th}}\) row of \(v^d\), and \(v_i^{d,a} \in R^k\) is the demographic attention result.
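A possible PyTorch sketch of Eq. (4); the module name and the parameter initialization are illustrative choices.

```python
import torch
import torch.nn as nn

class DemographicAttention(nn.Module):
    """Sketch of Eq. (4): attend over the |d| demographic embeddings,
    conditioned on each event embedding v_i^e."""

    def __init__(self, k: int):
        super().__init__()
        self.w_d = nn.Parameter(torch.randn(k) / k ** 0.5)   # W_{v,d} in R^k
        self.w_e = nn.Parameter(torch.randn(k) / k ** 0.5)   # W_{v,e} in R^k

    def forward(self, v_d: torch.Tensor, v_e: torch.Tensor) -> torch.Tensor:
        # v_d: (|d|, k) demographic embeddings; v_e: (n, k) event embeddings
        beta = v_d @ self.w_d + (v_e @ self.w_e).unsqueeze(-1)   # (n, |d|) scores
        alpha = torch.softmax(beta, dim=-1)                      # weights over demographics
        return alpha @ v_d                                       # (n, k): each row is v_i^{d,a}
```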
Given the embedding and attention results (i.e., \(v_i^e\), \(v_i^t\) and \(v_i^{d,a}\)), the \(i{\mathrm{th}}\) event and the patient’s demographics are projected into an embedding vector \(q_i \in R^k\) by concatenating the three vectors and applying a fully connected layer, written equivalently as a sum of three linear projections:
$$\begin{aligned} q_i = v_i^{e} W_{q,e} + v_i^{t} W_{q,t} + v_i^{d,a} W_{q,d} + b_q, \end{aligned}$$
(5)
where \(W_{q,e}, W_{q,t}, W_{q,d} \in R^{k \times k}\) and \(b_q \in R^k\) are learnable parameters.
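A sketch of Eq. (5); folding \(b_q\) into the bias of one of the linear layers is our simplification.

```python
import torch
import torch.nn as nn

class EventFusion(nn.Module):
    """Sketch of Eq. (5): fuse the value, time, and demographic-attention
    vectors into the final per-event embedding q_i."""

    def __init__(self, k: int):
        super().__init__()
        self.W_qe = nn.Linear(k, k, bias=False)
        self.W_qt = nn.Linear(k, k, bias=False)
        self.W_qd = nn.Linear(k, k, bias=True)   # its bias plays the role of b_q

    def forward(self, v_e, v_t, v_da):
        # all inputs have shape (n, k); the output q also has shape (n, k)
        return self.W_qe(v_e) + self.W_qt(v_t) + self.W_qd(v_da)
```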
Self attention module
Given a patient, his/her sequence of final event embeddings \(q = \{q_1, q_2,\ldots , q_n\}\) is fed into the self-attention module to capture useful patterns between related events. Three fully connected layers map q into three matrices \(Q, K, V \in R^{n \times k}\), the queries, keys and values, respectively. The self-attention output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. Specifically, we compute the dot product of each query \(Q_i\) with the keys \(K_j\) and obtain the attention weights \(\alpha _{ij}\) with a softmax function. Given the weights, the sum of the query event’s value and the attention-weighted values of the key events is output as the pattern attention outcome \(P \in R^{n \times k}\). The \(i{\mathrm{th}}\) row of P is computed as follows:
$$\begin{aligned} P_i& = V_i + \sum _j \alpha _{ij} V_j \nonumber \\ \alpha _{ij}& = \frac{\exp (\beta _{ij})}{ \sum _l \exp (\beta _{il}) } \nonumber \\ \beta _{ij}& = Q_i K_j^T \nonumber \\ Q_i& = q_i W_Q + b_Q \nonumber \\ K_i& = q_i W_K + b_K \nonumber \\ V_i& = q_i W_V + b_V, \end{aligned}$$
(6)
where \(W_Q, W_K, W_V \in R^{k \times k}\) and \(b_Q, b_K, b_V \in R^{k}\) are learnable parameters. Given two events i and j, the dot product between query \(Q_i\) and key \(K_j\) represents their relevance \(\beta _{ij}\). A softmax layer then generates the attention weight \(\alpha _{ij}\). Finally, a soft attention layer produces the pattern vector \(P_i\). The self-attention module captures two-event patterns; by stacking more self-attention layers, PAVE also has the potential to capture more complex medical patterns involving more events.
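A minimal sketch of the self-attention module of Eq. (6); faithful to the equation, no \(1/\sqrt{k}\) scaling is applied. The class name and tensor shapes are our assumptions.

```python
import torch
import torch.nn as nn

class PatternSelfAttention(nn.Module):
    """Sketch of Eq. (6): single-head self-attention with a residual
    term, so that P_i = V_i + sum_j alpha_ij V_j."""

    def __init__(self, k: int):
        super().__init__()
        self.to_q = nn.Linear(k, k)   # W_Q and b_Q
        self.to_k = nn.Linear(k, k)   # W_K and b_K
        self.to_v = nn.Linear(k, k)   # W_V and b_V

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # q: (n, k) sequence of final event embeddings q_1, ..., q_n
        Q, K, V = self.to_q(q), self.to_k(q), self.to_v(q)
        beta = Q @ K.T                        # (n, n) pairwise relevance; Eq. (6) uses
        alpha = torch.softmax(beta, dim=-1)   # no 1/sqrt(k) scaling, so none is applied
        return V + alpha @ V                  # residual keeps the query event's own value
```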
Pattern attention module
There are various patterns in each patient’s EHR data, but only some of them are useful for the risk prediction task. Given the pattern embeddings \(P \in R^{n \times k}\), a pattern attention mechanism is used to attend to the meaningful patterns.
$$\begin{aligned} h& = \sum _i \gamma _i P_i\nonumber \\ \gamma _i& = \frac{\exp (\theta _i)}{ \sum _j \exp (\theta _j)} \nonumber \\ \theta _i& = P_i W_p + b_p, \end{aligned}$$
(7)
where \(W_p \in R^{k}\) and \(b_p \in R\) are learnable parameters. Given a medical event pattern i for a patient, a fully connected layer computes its relevance \(\theta _i\) to the risk prediction task. A softmax layer then computes the weights for the different patterns. Finally, a soft attention combines the various patterns into a vector h, which contains the patient’s clinical risk information.
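Eq. (7) can be sketched as follows; returning the weights \(\gamma\) alongside h is our addition, anticipating the interpretability computation below.

```python
import torch
import torch.nn as nn

class PatternAttention(nn.Module):
    """Sketch of Eq. (7): score each pattern vector, softmax the scores,
    and pool all patterns into one risk representation h."""

    def __init__(self, k: int):
        super().__init__()
        self.score = nn.Linear(k, 1)   # W_p in R^k and scalar bias b_p

    def forward(self, P: torch.Tensor):
        # P: (n, k) pattern vectors from the self-attention module
        theta = self.score(P).squeeze(-1)     # (n,) relevance scores theta_i
        gamma = torch.softmax(theta, dim=-1)  # (n,) pattern weights gamma_i
        h = gamma @ P                         # (k,) pooled representation
        return h, gamma                       # gamma is reused for interpretability
```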
Objective function
A fully connected layer followed by a sigmoid layer predicts the risk probability:
$$\begin{aligned} y = sigmoid(h W_h + b_h), \end{aligned}$$
(8)
where \(W_h \in R^k\) and \(b_h \in R\) are learnable parameters. The cross-entropy between the ground truth \({\hat{y}}\) and the predicted result y is used as the loss:
$$\begin{aligned} L(y, {\hat{y}}) = -({\hat{y}} \log (y) + (1 - {\hat{y}}) \log (1 - y)). \end{aligned}$$
(9)
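A sketch of the prediction head and loss of Eqs. (8) and (9); the batch dimension and the use of PyTorch’s built-in binary cross-entropy are our choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

k = 64                     # embedding dimension (illustrative)
head = nn.Linear(k, 1)     # W_h in R^k and scalar bias b_h

def risk_and_loss(h: torch.Tensor, y_true: torch.Tensor):
    """Eqs. (8)-(9): sigmoid prediction head plus binary cross-entropy.

    h      : (batch, k) patient representations from the pattern attention
    y_true : (batch,) ground-truth labels in {0, 1}
    """
    y = torch.sigmoid(head(h)).squeeze(-1)      # (batch,) predicted risks
    loss = F.binary_cross_entropy(y, y_true)    # Eq. (9), averaged over the batch
    return y, loss
```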
Interpretability
PAVE is interpretable in the sense that it can compute each pattern’s contribution to the output. Given a pattern (i, j), consisting of events i and j, the contribution \(C_{ij}\) is calculated as follows:
$$\begin{aligned} C_{ij} = \gamma _i \alpha _{ij}. \end{aligned}$$
(10)
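Given the weights produced by the two attention modules, Eq. (10) is a simple elementwise scaling, sketched below; the function name is ours.

```python
import torch

def pattern_contributions(gamma: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Eq. (10): contribution C_ij = gamma_i * alpha_ij of each two-event
    pattern (i, j) to the prediction.

    gamma : (n,) pattern weights from the pattern attention module
    alpha : (n, n) self-attention weights
    """
    return gamma.unsqueeze(-1) * alpha   # (n, n) contribution matrix
```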