Leveraging auxiliary measures: a deep multi-task neural network for predictive modeling in clinical research

Background Accurate predictive modeling in clinical research enables effective early intervention for the patients most likely to benefit from it. However, due to the complex biological nature of disease progression, capturing highly non-linear information from low-level input features is challenging and requires predictive models with high capacity. In practice, clinical datasets are often of limited size, which poses a risk of overfitting for high-capacity models. To address these two challenges, we propose a deep multi-task neural network for predictive modeling. Methods The proposed network leverages clinical measures that are related to the primary target as auxiliary targets. The neural network predicts the primary and auxiliary targets simultaneously. The network structure is specifically designed to capture clinical relevance by learning a feature representation shared between the primary and auxiliary targets. We apply the proposed model to a hypertension dataset and a breast cancer dataset, where the primary tasks are to predict the left ventricular mass indexed to body surface area and the time to recurrence of breast cancer, respectively. Moreover, we analyze the weights of the proposed neural network to rank input features for model interpretability. Results The experimental results indicate that the proposed model outperforms the competing models, achieving the best predictive accuracy (mean squared error 199.76 for the hypertension data and 860.62 for the Wisconsin prognostic breast cancer data), while ranking features according to their contributions to the targets. The ranking is supported by previous related research. Conclusion We propose a novel, effective method for clinical predictive modeling that combines a deep neural network with multi-task learning. By leveraging auxiliary measures clinically related to the primary target, our method improves predictive accuracy.
Interpreted through feature ranking, our model shows consistency with previous studies on cardiovascular disease and cancer.

recognition, computer vision and healthcare informatics [1][2][3][4]. Compared with linear regression, DNNs can learn high-level feature representations and make better predictions based on those abstract features. This enables DNNs to capture non-linear relations among low-level features, making them promising for clinical research.
Successful deep neural networks require abundant labeled data to learn useful feature representations effectively. However, in clinical practice, collecting labeled data is expensive and time-consuming. As a result, only a limited amount of labeled data is available, and a high-capacity model fitted to such a small labeled set is prone to overfitting.
To avoid overfitting in DNNs, various regularization methods, such as dropout, early stopping and L2 regularization [5], have been developed. In clinical research, once the primary targets are defined, we can further mitigate overfitting by leveraging other clinical measures generated by the labeling process. As these measures are clinically related to the primary targets, we can integrate them into a multi-task framework as a form of regularization that benefits the model.
For instance, some demographic subpopulations with hypertension are more likely to develop left ventricular hypertrophy (LVH), a form of structural heart damage that results from poor blood pressure control. Left ventricular mass indexed to body surface area (LVMI) is a commonly used measure for determining whether LVH is present. However, measuring LVMI requires advanced imaging, and it is difficult to know which patients should undergo testing. LVMI is also challenging to predict, as no single input feature has enough explanatory power for LVH. Accurately predicting LVH status for hypertensive patients is critical, because definitive testing to diagnose LVH, including cardiac magnetic resonance imaging (CMR), is expensive, and testing every patient with hypertension would be cost-prohibitive.
In this paper, we propose the generalized auxiliary-task augmented network (GATAN), which extends [6] from regression to general supervised learning tasks. GATAN is a multi-task predictive neural network that predicts the primary target and the auxiliary targets simultaneously (see Fig. 1). Under the multi-task learning framework, the auxiliary tasks can be viewed as a regularization method as well as implicit data augmentation [5,7,8,9,10], so GATAN can reduce the risk of overfitting. Since there is no universal definition of inter-task relatedness, GATAN learns task-specific feature representations as well as a representation shared by all tasks that conceptually captures their relation. The learned representations are then combined using a weighting mechanism, and GATAN makes predictions based on the combined high-level features. Finally, to interpret GATAN, we adopt a heuristic method that analyzes the learned weights to rank the contributions of input features.

Generalized auxiliary-task augmented network
Consider the motivating example in Fig. 1: with LVMI as the primary target, the labeling process also produces additional CMR measures that characterize heart morphology, including septal, posterior and anterior wall thickness. These measures are clinically related to LVMI, and predictive models can exploit them as auxiliary prediction tasks.
However, the clinical "relevance" is not clearly defined. To circumvent this issue, GATAN learns a feature representation that can be decomposed into a weighted sum of the shared and task-specific feature representations. The shared representation conceptually models the relevance between tasks. Figure 2 displays the GATAN structure. We use a feed-forward deep neural network (FDNN) [5] as the building block of GATAN.
Assume that $(x, y_c, y_a)$ is a sample with input features $x$, primary target $y_c$ and auxiliary target $y_a$. The shared and task-specific feature representations are learned as follows:
$$h_s = f_s(x), \qquad h_c = f_c(x), \qquad h_a = f_a(x),$$
where each $f(\cdot)$ is modeled by an FDNN with multiple stacked hidden layers and a non-linear activation (element-wise sigmoid in our case). These feature representations are then combined to form the final representations $h_{fc}$ and $h_{fa}$:
$$h_{fc} = a_1 h_c + a_2 h_s, \qquad h_{fa} = b_1 h_a + b_2 h_s,$$
where $\{a_1, a_2\}$ and $\{b_1, b_2\}$ are the weights that quantify the contributions of $h_s$, $h_c$ and $h_a$. Note that in this formulation $h_s$, $h_c$ and $h_a$ have the same dimension. Alternatively, the task-specific and shared feature representations could be combined through vector concatenation, $h_{fc} = [h_c, h_s]$. However, this approach can require more parameters for each $h$ to have enough representational power. We therefore prefer the weighted sum when only a limited amount of data is available.
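As an illustration of the shared and task-specific representations, the following is a minimal pure-Python sketch of three FDNNs f_s, f_c and f_a with element-wise sigmoid activations. It is not the authors' PyTorch implementation; the layer sizes, the Gaussian initialization, the random seed and the input values are hypothetical choices made only so the example runs.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def init_layer(n_in, n_out, rng):
    # One weight row per output neuron; illustrative Gaussian initialization
    weights = [[rng.gauss(0.0, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    bias = [0.0] * n_out
    return weights, bias

def make_fdnn(sizes, rng):
    # sizes = [n_inputs, n_hidden_1, ..., n_hidden_L]
    return [init_layer(a, b, rng) for a, b in zip(sizes[:-1], sizes[1:])]

def fdnn_forward(x, layers):
    # Stacked fully connected layers, each followed by an element-wise sigmoid
    h = x
    for weights, bias in layers:
        h = [sigmoid(sum(w * v for w, v in zip(row, h)) + b)
             for row, b in zip(weights, bias)]
    return h

rng = random.Random(0)
sizes = [5, 8, 4]              # hypothetical: 5 input features, hidden layers of 8 and 4
f_s = make_fdnn(sizes, rng)    # shared network
f_c = make_fdnn(sizes, rng)    # primary-task network
f_a = make_fdnn(sizes, rng)    # auxiliary-task network

x = [0.2, -1.0, 0.5, 1.3, 0.0]  # one made-up, standardized input sample
h_s, h_c, h_a = (fdnn_forward(x, f) for f in (f_s, f_c, f_a))
```

All three networks end with the same width (4 here), which is what makes a weighted sum of h_s, h_c and h_a well defined; in training, the weights of all three networks would be fitted jointly.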
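[Figure caption: The labeling process also produces other measures that are clinically related to LVMI. We predict these measures as auxiliary tasks in our model.]
To compute $\{a_1, a_2\}$ and $\{b_1, b_2\}$, we use the cosine similarity, denoted $\mathrm{cosd}$. For the primary task,
$$a_1 = \frac{\mathrm{cosd}(h_c, h_s)}{1 + \mathrm{cosd}(h_c, h_s)}, \qquad a_2 = \frac{1}{1 + \mathrm{cosd}(h_c, h_s)},$$
and for the auxiliary task,
$$b_1 = \frac{\mathrm{cosd}(h_a, h_s)}{1 + \mathrm{cosd}(h_a, h_s)}, \qquad b_2 = \frac{1}{1 + \mathrm{cosd}(h_a, h_s)},$$
where $\mathrm{cosd}(v_1, v_2) = v_1 \cdot v_2 / (\|v_1\|_2 \|v_2\|_2)$ and $\|\cdot\|_2$ is the Euclidean norm of a vector. Since we use the sigmoid activation, all entries of $h_s$, $h_c$ and $h_a$ are positive, so $\mathrm{cosd} \in (0, 1]$, and $\{a_1, a_2\}$ and $\{b_1, b_2\}$ are positive weights that sum to one. Note that this strategy is biased toward the shared feature representation and forces it to contribute at least half of the final feature representation (i.e. $a_2, b_2 \ge 0.5$), reflecting the benefits of multi-task learning.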
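The cosine-similarity weighting can be sketched as follows. This sketch assumes the weights take the form a1 = cosd(h_c, h_s) / (1 + cosd(h_c, h_s)) and a2 = 1 / (1 + cosd(h_c, h_s)), one form consistent with the properties stated above (positive weights, with the shared weight at least 0.5); the helper name combine and the example vectors are hypothetical.

```python
import math

def cosd(v1, v2):
    # Cosine similarity: v1 . v2 / (||v1||_2 * ||v2||_2)
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    return dot / (norm1 * norm2)

def combine(h_task, h_shared):
    # Assumed weight form: task weight c / (1 + c), shared weight 1 / (1 + c)
    c = cosd(h_task, h_shared)  # lies in (0, 1] for positive (sigmoid) vectors
    w_task, w_shared = c / (1.0 + c), 1.0 / (1.0 + c)
    h_final = [w_task * t + w_shared * s for t, s in zip(h_task, h_shared)]
    return h_final, (w_task, w_shared)

h_s = [0.9, 0.2, 0.6, 0.4]  # hypothetical shared representation
h_c = [0.1, 0.8, 0.3, 0.7]  # hypothetical primary-task representation
h_fc, (a1, a2) = combine(h_c, h_s)
```

With these vectors the sketch gives a1 ≈ 0.35 and a2 ≈ 0.65. When h_c equals h_s, cosd is 1 and both weights equal 0.5, the smallest share the shared representation can receive under this form.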
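Based on the final feature representations, the predictions $\hat{y}_c$ and $\hat{y}_a$ are calculated as
$$\hat{y}_c = l(W_c^\top h_{fc} + \beta_c), \qquad \hat{y}_a = l(W_a^\top h_{fa} + \beta_a),$$
where $W_c$ and $W_a$ are dimension-compatible weight vectors (matrices for multi-class outputs), $\beta_c$ and $\beta_a$ are bias terms, and $l(\cdot)$ is the link function, which depends on the specific prediction task. When targets are continuous, $l(\cdot)$ is the identity function; for classification tasks, $l(\cdot)$ is the sigmoid or softmax function:
$$\mathrm{sigmoid}(z) = \frac{1}{1 + e^{-z}}, \qquad \mathrm{softmax}(z)_k = \frac{e^{z_k}}{\sum_j e^{z_j}}.$$
The joint objective function is the sum of the losses for the individual tasks:
$$\mathcal{L}(\Theta) = \sum_{i=1}^{n} \left[ \ell\big(y_c^{(i)}, \hat{y}_c^{(i)}\big) + \omega\, \ell\big(y_a^{(i)}, \hat{y}_a^{(i)}\big) \right],$$
where, for notational brevity, $\Theta$ denotes the set of parameters of the neural network, $n$ is the number of training samples, and $\omega$ is a hyperparameter balancing the tasks during training. We use $\omega = 1$ in our experiments.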
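The prediction heads and link functions can be sketched in pure Python as below. The weights, the biases (called beta here) and the final representations are made-up numbers, and the helper names affine and affine_multi are hypothetical.

```python
import math

def identity(z):
    return z

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def affine(w, h, beta):
    # Scalar output: w^T h + beta
    return sum(wi * hi for wi, hi in zip(w, h)) + beta

def affine_multi(W, h, betas):
    # Vector output: one weight row and one bias per class
    return [affine(row, h, b) for row, b in zip(W, betas)]

h_fc = [0.6, 0.3, 0.5, 0.5]  # hypothetical final representations
h_fa = [0.4, 0.7, 0.2, 0.9]

y_c_hat = identity(affine([2.0, -1.0, 0.5, 1.0], h_fc, 0.1))  # continuous target
y_a_hat = sigmoid(affine([1.0, 1.0, -2.0, 0.5], h_fa, -0.5))  # binary target
y_a_probs = softmax(affine_multi([[1.0, 0.0, 0.0, 0.0],
                                  [0.0, 1.0, 0.0, 0.0]], h_fa, [0.0, 0.0]))  # 2-class
```

The continuous head returns 1.75, the binary head a probability of about 0.66, and the two-class softmax head a probability vector that sums to one.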
For regression, the loss function is the squared loss,
$$\ell(y, \hat{y}) = (y - \hat{y})^2,$$
and for classification, it is the cross-entropy,
$$\ell(y, \hat{y}) = -\sum_{k} y_k \log \hat{y}_k,$$
where the class label $y$ is encoded as a one-hot vector. GATAN also accommodates multiple auxiliary targets, which can be incorporated straightforwardly, as well as different loss functions for different tasks, e.g. one regression task and one classification task.
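A sketch of the joint objective for a mixed setting like the WPBC experiment (a regression primary target and a two-class auxiliary target) follows. It assumes the per-sample losses are summed over the training set with ω = 1; the sample values are made up, and the small eps guard inside the logarithm is an implementation choice rather than part of the formulation.

```python
import math

def squared_loss(y, y_hat):
    return (y - y_hat) ** 2

def cross_entropy(y_onehot, probs, eps=1e-12):
    # eps guards against log(0); terms with y_k = 0 contribute nothing
    return -sum(y * math.log(max(p, eps)) for y, p in zip(y_onehot, probs))

def joint_objective(batch, omega=1.0):
    # Sum over samples of primary loss + omega * auxiliary loss
    total = 0.0
    for y_c, y_c_hat, y_a, y_a_probs in batch:
        total += squared_loss(y_c, y_c_hat) + omega * cross_entropy(y_a, y_a_probs)
    return total

# (primary target, primary prediction, one-hot auxiliary label, auxiliary probabilities)
batch = [
    (30.0, 28.0, [1, 0], [0.8, 0.2]),
    (12.0, 15.0, [0, 1], [0.5, 0.5]),
]
loss = joint_objective(batch, omega=1.0)
```

For these two samples, the squared losses contribute 4 + 9 = 13 and the cross-entropy terms contribute -log 0.8 - log 0.5 ≈ 0.92, for a total of about 13.92.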

Feature ranking
Model interpretability is another important aspect of clinical practice. While there is no systematic way to interpret deep networks, we can extend ideas from linear regression and calculate the contribution of each input feature by back-propagating each neuron's contribution through its connections to the previous layer of neurons [11].
To illustrate how each neuron's contribution is propagated back from the target, consider the example shown in Fig. 3.
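Let $W = (w_{kj}) \in \mathbb{R}^{2 \times 3}$ and $v = (v_j) \in \mathbb{R}^{3}$ be the two weight matrices associated with the last two hidden layers, where $w_{kj}$ connects neuron $g_k$ to neuron $h_j$ and $v_j$ connects $h_j$ to the target $y$. The contribution of $h_j$ can be computed as in linear regression, for $j = 1, 2, 3$:
$$C_{jy} = \frac{|v_j|}{\sum_{j'=1}^{3} |v_{j'}|}.$$
Similarly, the contribution $C_{kj}$ of $g_k$ ($k = 1, 2$) to $h_j$ is
$$C_{kj} = \frac{|w_{kj}|}{\sum_{k'=1}^{2} |w_{k'j}|}.$$
Then the contribution $C_{kjy}$ from $g_k$ through $h_j$ to $y$ is
$$C_{kjy} = C_{kj}\, C_{jy}.$$
Since there are three paths from $g_k$ to $y$, through $h_1$, $h_2$ and $h_3$, the total contribution $C_{ky}$ of $g_k$ is
$$C_{ky} = \sum_{j=1}^{3} C_{kjy}.$$
We can keep propagating the contributions of neurons back to the input features to calculate their contributions to the target.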
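The path-based propagation can be sketched in pure Python. This sketch assumes that each neuron's contribution is its absolute connection weight normalized over all connections into the same downstream neuron (in the spirit of Garson's connection-weight method); the function names and the toy weights, which mimic the 2-3-1 layout of Fig. 3, are hypothetical.

```python
def normalize_columns(M):
    # C[k][j] = |M[k][j]| / sum_k' |M[k'][j]|: share of neuron j's input from neuron k
    n_rows, n_cols = len(M), len(M[0])
    col_sums = [sum(abs(M[k][j]) for k in range(n_rows)) for j in range(n_cols)]
    return [[abs(M[k][j]) / col_sums[j] if col_sums[j] else 0.0
             for j in range(n_cols)] for k in range(n_rows)]

def contributions_to_output(weight_matrices):
    """
    weight_matrices: ordered from the input side to the output side; M[k][j] is
    the weight from neuron k of one layer to neuron j of the next. The last
    matrix has a single column (the target y).
    Returns the contribution of each neuron in the first layer to y.
    """
    # Contributions of the last hidden layer: C_jy = |v_j| / sum_j' |v_j'|
    contrib = [row[0] for row in normalize_columns(weight_matrices[-1])]
    for M in reversed(weight_matrices[:-1]):
        C = normalize_columns(M)
        # Sum over all paths k -> j -> ... -> y: C_ky = sum_j C_kj * C_jy
        contrib = [sum(C[k][j] * contrib[j] for j in range(len(contrib)))
                   for k in range(len(C))]
    return contrib

# Toy version of the Fig. 3 layout (weights are made up)
W = [[0.5, -1.0, 2.0],      # g_1 -> h_1, h_2, h_3
     [1.5, 1.0, -2.0]]      # g_2 -> h_1, h_2, h_3
v = [[1.0], [-3.0], [0.0]]  # h_1, h_2, h_3 -> y
C = contributions_to_output([W, v])  # [C_1y, C_2y]
```

Here C evaluates to [0.4375, 0.5625]. Under this normalization the contributions of one layer always sum to one, and appending more weight matrices to the front of the list keeps propagating the contributions back toward the input features.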
In GATAN, each input feature can contribute to the target through both the task-specific network and the shared network. If $C^{c}_{ky_c}$ and $C^{s}_{ky_c}$ are the contributions of feature $x_k$ to $y_c$ through the task-specific and shared networks, respectively, the overall contribution of $x_k$ is the weighted sum $C_{ky_c} = a_1 C^{c}_{ky_c} + a_2 C^{s}_{ky_c}$, where $a_1$ and $a_2$ are the weights used to combine $h_c$ and $h_s$ for the primary task. This provides a heuristic approach for interpreting GATAN.
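Combining the task-specific and shared contributions and ranking the features can be sketched as below. The sketch treats a1 and a2 as fixed numbers (for example, averages over the training samples, which is an assumption, since the weights depend on each sample's representations); the feature names x1 to x3 and all values are hypothetical.

```python
def overall_contribution(c_task, c_shared, a1, a2):
    # C_{k,y_c} = a1 * C^c_{k,y_c} + a2 * C^s_{k,y_c}
    return [a1 * t + a2 * s for t, s in zip(c_task, c_shared)]

def rank_features(names, scores, top=None):
    # Sort features by overall contribution, largest first
    order = sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)
    return order if top is None else order[:top]

names = ["x1", "x2", "x3"]
c_task = [0.5, 0.3, 0.2]    # contributions through the task-specific network
c_shared = [0.2, 0.2, 0.6]  # contributions through the shared network
scores = overall_contribution(c_task, c_shared, a1=0.4, a2=0.6)
ranking = rank_features(names, scores)
```

With a1 = 0.4 and a2 = 0.6, the scores are 0.32, 0.24 and 0.44, so x3 ranks first even though x1 has the largest task-specific contribution.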

Datasets and preprocessing
Hypertension dataset The cohort was derived from an NIH-funded study of African American patients with hypertension and elevated systolic blood pressure (>160 mm Hg) at the emergency department of Detroit Receiving Hospital. Previous studies have shown disparities among hypertensive patients, some of whom are at greater risk of LVH. A DNN model capable of capturing complex feature interactions is therefore promising for predicting LVMI.
In the labeling process of LVMI, other measures that characterize heart morphology, such as left ventricular stroke volume indexed to body surface area (LVSVI), left ventricular end-diastolic volume indexed to body surface area (LVEDVI) and septal, posterior and anterior wall thickness, are also produced. These measures are closely related to LVMI and provide additional information that GATAN can use as auxiliary tasks.
The original dataset contains 155 samples and 65 measures: LVMI, 59 input features (demographics, lab results, heart function measures, etc.) and 5 other CMR measures that serve as candidate auxiliary targets. Table 1 and the left panel of Fig. 4 present basic statistics of the targets.
From the perspective of predictive modeling, a model that uses only lab results and demographics as features (34 in total) is preferable, as these features are more widely accessible and informative about disease progression than the full feature set, which contains heart function measures. Hence, we also conduct experiments with this reduced feature set.
The Wisconsin prognostic breast cancer dataset (WPBC) is a publicly available dataset in the UCI repository [12]. Its input features are computed from digitized images of fine needle aspirates (FNA) of breast masses. These derived features include the mean value, standard error and largest ("worst") value of 10 features: radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry and fractal dimension. The primary target is the "time to recurrence of breast cancer"; the auxiliary target is the recurrence status ("recur" or "non-recur").
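We use PyTorch [14] to build GATAN. In our experiments, each time only one CMR measure is selected as the auxiliary target. To evaluate performance, the following three metrics are used:
• Mean squared error (MSE) measures the average squared prediction error, without normalizing by the scale of the target:
$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2.$$
• Explained variance score (EVS):
$$\mathrm{EVS} = 1 - \frac{\mathrm{Var}(y - \hat{y})}{\mathrm{Var}(y)},$$
where $\mathrm{Var}(\cdot)$ is the variance.
• Median absolute error (MAE) is more robust to outliers than MSE; it computes the median of the absolute prediction errors:
$$\mathrm{MAE} = \mathrm{median}\big(|y_1 - \hat{y}_1|, \ldots, |y_n - \hat{y}_n|\big).$$
Smaller MSE and MAE are better, while a larger EVS is better.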
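The three metrics can be computed with the Python standard library as sketched below; in practice, scikit-learn's sklearn.metrics module provides mean_squared_error, explained_variance_score and median_absolute_error. Note that MAE here denotes the median, not the mean, absolute error. The targets and predictions are made-up values.

```python
import statistics

def mse(y, y_hat):
    return sum((a - b) ** 2 for a, b in zip(y, y_hat)) / len(y)

def explained_variance(y, y_hat):
    residuals = [a - b for a, b in zip(y, y_hat)]
    # The ratio is the same whether both variances divide by n or by n - 1
    return 1.0 - statistics.pvariance(residuals) / statistics.pvariance(y)

def median_absolute_error(y, y_hat):
    return statistics.median([abs(a - b) for a, b in zip(y, y_hat)])

y_true = [100.0, 120.0, 90.0, 150.0]  # made-up targets
y_pred = [110.0, 115.0, 95.0, 130.0]  # made-up predictions
mse_value = mse(y_true, y_pred)
evs = explained_variance(y_true, y_pred)
medae = median_absolute_error(y_true, y_pred)
```

For these values the functions return MSE = 137.5, EVS = 0.75 and MAE = 7.5.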

Hypertension data
Using entire feature set We first experiment with the full feature set. From the table of test-set results, GATAN with LVEDVI as the auxiliary target (i.e. GATAN-1) achieves the best predictive performance. For example, GATAN-1 reduces MSE by approximately 3% compared with Lasso; compared with MTLasso, the GATAN models also perform better, with margins of 5% (MSE), 13% (EVS) and 2% (MAE). The table also shows that GATAN improves over MLP-4, owing to the introduction of auxiliary tasks. This indicates that GATAN benefits from the auxiliary task, which acts as a regularizer in multi-task learning.
MTLasso also introduces auxiliary tasks. However, MTLasso does not improve over Lasso. MTLasso assumes that all tasks share the same subset of effective features, i.e. that LVMI and LVEDVI have the same feature structure, which is too restrictive. In contrast, GATAN makes a less restrictive assumption about the clinical "relevance": it captures the relevance by learning a shared feature representation. This suggests that a proper assumption on task relatedness is crucial for multi-task learning.
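To see why the shared-support assumption is restrictive, assume MTLasso refers to the standard multi-task Lasso (as implemented, e.g., in scikit-learn's MultiTaskLasso), whose objective is (1/(2n))||Y - XW||_F^2 + α||W||_{2,1} with ||W||_{2,1} = Σ_i sqrt(Σ_j w_ij^2), where rows of W index features and columns index tasks. The sketch below compares this penalty for two coefficient matrices with the same L1 norm; the matrices are hypothetical.

```python
import math

def l1_norm(W):
    return sum(abs(w) for row in W for w in row)

def l21_norm(W):
    # Rows = features, columns = tasks: sum of per-feature Euclidean norms
    return sum(math.sqrt(sum(w * w for w in row)) for row in W)

# Both tasks use feature 0 (shared support)
W_shared = [[1.0, 1.0],
            [0.0, 0.0]]
# Each task uses a different feature (disjoint support)
W_disjoint = [[1.0, 0.0],
              [0.0, 1.0]]
```

Both matrices have an L1 norm of 2, but the L2,1 penalty is √2 ≈ 1.41 for the shared support and 2 for the disjoint support, so the penalty pushes all tasks toward selecting the same features.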
Finally, the explained variance score (EVS) on the test data is unsatisfactory for all models. By definition, EVS is very sensitive to poor predictions, which suggests that all models fail on some test samples. The histogram of LVMI (Fig. 4) suggests that the data might be generated from a multi-modal distribution, and all models fail to capture this local data structure.
We further explored the predictive behavior of GATAN and found that the models often make poor predictions at the tails of the sample distribution (results not shown). In the hypertension dataset, the Pearson correlation between LVMI and calcium level is 0.79 at the right tail (LVMI >120), and a two-tailed correlation test shows that this correlation is statistically significant (p-value <0.001). However, the Pearson correlation between LVMI and calcium is 0.00 for the […]. Previous studies [15,16] showed that, compared with patients without LVH, patients with LVH exhibit a strong positive correlation with serum calcium level. Our observations are consistent with these findings. This disparity in the correlation between LVMI and calcium among hypertensive patients implies that LVH prevalence differs among patient subgroups.
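The tail analysis can be sketched as follows: restrict to samples with LVMI above a threshold, compute the Pearson correlation r, and form the statistic t = r·sqrt((n - 2)/(1 - r^2)), which follows a Student t distribution with n - 2 degrees of freedom under the null hypothesis of zero correlation. The LVMI and calcium values below are made up, not study data; a two-sided p-value can be obtained from the t distribution, for example via scipy.stats.pearsonr.

```python
import math

def pearson(x, y):
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return sxy / math.sqrt(sxx * syy)

def tail_correlation(lvmi, calcium, threshold=120.0):
    # Keep only the right tail of the LVMI distribution
    pairs = [(l, c) for l, c in zip(lvmi, calcium) if l > threshold]
    xs, ys = zip(*pairs)
    r = pearson(xs, ys)
    n = len(xs)
    # t statistic with n - 2 degrees of freedom for H0: rho = 0
    t = r * math.sqrt((n - 2) / (1.0 - r * r))
    return r, t, n

lvmi = [90.0, 100.0, 125.0, 130.0, 140.0, 150.0]  # made-up values
calcium = [9.5, 9.0, 9.1, 9.3, 9.6, 9.8]          # made-up values
r, t, n = tail_correlation(lvmi, calcium)
```

For the four made-up samples above the threshold, r is about 0.99.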
Using demographics and lab results only We use the same experimental setup as with the full feature set. Table 2 shows the predictive performance with this more limited feature set. Our multi-task neural networks (GATAN-1 and GATAN-2) perform better than the other models, suggesting that learning high-level feature representations benefits predictive modeling. However, compared with the full-feature setup, excluding heart function measures from the input features degrades model performance, as functional measures are expected to be more informative for predicting LVMI.
Interpreting GATAN Figure 5a and b show the top-20 features from the full feature set for two different auxiliary tasks. Comparing the two panels, we see that the feature ranking in Fig. 5a approximately matches that in Fig. 5b. Sex is the most important feature. In the hypertension dataset, the sample mean LVMI of males versus females is 95.78 vs. 85.21, and the difference is statistically significant (p-value <0.0001, two-sample t-test). The figure also shows that other features with significant contributions are functional measures, such as ejection duration, LV ejection fraction and the Cornell product (an electrocardiographic predictor of LVH). This is sensible, since heart structure and function are inherently related.
[…] matched as those in Fig. 5. From the figure, systolic and diastolic blood pressure are the most important features for predicting LVMI. The relationship between hypertension and LVH was the basic premise of our study, and this result is not surprising given the finding in [17] that elevated blood pressure corresponds with high LVMI. Moreover, GATAN identifies more subtle relations between lab results and LVMI, including potassium, vitamin D, calcium, diabetes status, blood urea nitrogen (BUN) and renin. These top-ranked features accord with previous research [15,18,19], suggesting that ranking features by analyzing the learned weights is a reasonable heuristic for interpreting deep neural networks.
Table 3 shows the performance of different models on the WPBC test data. In terms of MSE and MAE, GATAN achieves the smallest prediction errors, 860.625 and 23.860, respectively. For the explained variance score (EVS), all models perform poorly. One likely reason is that the distribution of the primary target "time to recur" is highly right-skewed, which makes it difficult for the models to fit the long tail well.
Figure 7 presents the top-10 features for predicting "time to recur", including FNA area, radius and texture. This is intuitive, as morphological measures are informative about breast cancer.

Conclusions
In this paper, we propose a deep multi-task neural network, GATAN, for predictive modeling in clinical research. GATAN leverages additional information in the modeling process by introducing clinical measures as auxiliary targets. As a DNN model, GATAN is capable of high-level feature learning and flexibly captures the clinical relevance between the primary and auxiliary targets. Our experiments on two different datasets, each using one auxiliary task, demonstrate that GATAN can achieve superior performance compared with traditional models when only a limited amount of labeled data is available.