Deep Learning with Diverse Objectives Improves ARDS Prediction

When training neural networks (NNs) on time-series inpatient data, as the set of predicted outcomes diversifies, the NN both generalizes better on external validation and reaches higher performance in a similar number of training epochs. We demonstrated this in the context of predicting decompensation in Acute Respiratory Distress Syndrome (ARDS). The NN outperformed gradient boosted trees, achieving an area under the receiver operating characteristic curve of 0.86 on an external hold-out test set of hospitals not included in the training set. We estimated real-world benefit by comparing mortality rates between similarly at-risk patients diagnosed before or after the time of algorithm evaluation, and showed that, among similarly at-risk patients, earlier diagnosis of ARDS nearly doubles the rate of in-hospital survival. Using cluster analysis of the algorithm's internal representations, we identified distinct ARDS sub-groups, some of which had similar mortality rates but different clinical presentations.

Figure 1. Flow Chart of Patients. Among 7 hospitals, 40,703 patients met 3 criteria: 1. admission within the date range 4/20/2018 to 3/17/2021, 2. length of stay within the range of 2 hours to 3 months, and 3. availability of basic vitals (blood pressure, heart rate, temperature, respiratory rate, and peripheral oxygen saturation) and labs (complete blood count, CBC, and basic metabolic panel, BMP) in the EHR. These patients were separated into training, validation, and test sets based on their hospital sites for external validation. The test set was limited to patients with the required features listed in Table 1, consisting of age, sex, and basic labs plus CBC with differential.

Input Features
Model inputs were a defined set of data types, or features, collected across all hospitals regardless of the data availability at a particular hospital. Table 1 lists all the features used to train the Machine Learning Algorithms (MLAs) in this study. The required features are the subset of features, consisting of age, sex, and basic labs plus CBC with differential, used to determine the time at which the algorithm makes its prediction. These data values were organized into a matrix with features along the first dimension (rows) and discrete time in 20-minute intervals along the second dimension (columns). The first column, column index 0, contains the first time-point of any vital or lab measurement and was considered the start of care. The first row was normalized age. The second and third rows were binary indicators for male and female sex. The remaining rows were the time-varying features and their corresponding masks distinguishing missing values from actual zeros (Table 2).

Table 1. Input Features to the Machine Learning Algorithm. Features with an asterisk (*) are "required" features.

Demographics
Age*, Sex*

Other Measurements
Systemic Inflammatory Response Syndrome (SIRS)

To normalize the features, we made a coarse approximation of the mean and standard deviation based on the normal range of each feature in the lab reports. The center of the normal range was used as the approximate mean, and half the range was used to approximate the standard deviation. If a feature was missing or not measured, it was set to 0. To let the model distinguish between null values and real values, a new set of features representing the availability mask was vertically appended to the matrix. Each feature row had a corresponding binary mask vector containing 0s and 1s, representing null and non-null values respectively. During batch training, these matrices were zero-padded on the left into equally sized tensors of shape [batch size, 58 features, 64 timesteps].
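The feature-matrix construction described above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions: the normal ranges shown are placeholder values, not the study's actual ranges, and the helper names (`build_matrix`, `left_pad`) are our own.

```python
import numpy as np

# Placeholder normal ranges (illustrative only, not the study's Table 1):
# approximate mean = center of the normal range, std = half the range.
NORMAL_RANGES = {"heart_rate": (60.0, 100.0), "temperature": (36.1, 37.2)}

def normalize(feature, value):
    lo, hi = NORMAL_RANGES[feature]
    mean, std = (lo + hi) / 2.0, (hi - lo) / 2.0
    return (value - mean) / std

def build_matrix(measurements, features, n_timesteps):
    """measurements: {feature: {timestep: value}} -> stacked [values; mask].
    Missing entries stay 0 in the value rows; the mask rows mark 1 where a
    real measurement exists, so the model can tell nulls from true zeros."""
    values = np.zeros((len(features), n_timesteps))
    mask = np.zeros((len(features), n_timesteps))
    for i, f in enumerate(features):
        for t, v in measurements.get(f, {}).items():
            values[i, t] = normalize(f, v)
            mask[i, t] = 1.0
    return np.vstack([values, mask])

def left_pad(mat, total_timesteps):
    """Zero-pad on the left so all patients share the same tensor width."""
    pad = total_timesteps - mat.shape[1]
    return np.pad(mat, ((0, 0), (pad, 0)))
```

Stacking value rows above their mask rows doubles the row count, which is consistent with the 58-feature tensor containing both measurements and availability indicators.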

Model Output and Targets
Additional target labels were chosen that are distinct from ARDS yet clinically related to it, such that a collective representation in the NN is justified. These labels are shown in Table 3. The model was trained to predict the following target labels, also referred to as outcomes, using a binary cross entropy loss function.

Table 3. Clinical Outcomes used as Target Labels for the Machine Learning Algorithm. Thirteen output labels mapped to their respective definitions. These output labels are used throughout the paper hereafter.

Timing of Algorithm Evaluation
For simplicity, we evaluated the algorithm at a single point in time: two timesteps after the first time at which all required features had been measured at least once. At this time, which we call the Algotime, the model predicts all of the target outcomes of Table 3. In the training and validation sets, we used the required features to determine Algotime; in the case of missing features, it defaulted to 8 hours after admission. In the test set, we included only patients with all of the required features. As can be seen in Figure 2, on average the Algotime occurred 31 hours after admission and ARDS was clinically diagnosed 139 hours after admission; hence, the average interval between the Algotime and the clinical diagnosis of ARDS was 108 hours. It should be noted that the performance statistics reported in this paper correspond to the Algotime, not the time of clinical ARDS diagnosis or the end of the stay.
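The Algotime arithmetic can be sketched as below; the function name and dictionary layout are our own illustration. With 20-minute timesteps, the 8-hour fallback corresponds to 24 steps.

```python
def algotime(first_seen, required, n_timesteps=64, default_step=24):
    """first_seen: {feature: timestep of first measurement (or None)}.
    Algotime = 2 timesteps after the first step at which every required
    feature has been measured at least once. If any required feature is
    missing, fall back to `default_step` (8 h after admission = 24 steps
    of 20 minutes), as in the training/validation sets."""
    times = [first_seen.get(f) for f in required]
    if any(t is None for t in times):
        return default_step
    # All required features available once the latest of them appears.
    return min(max(times) + 2, n_timesteps - 1)
```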

Benefit Estimation
Thus far we have defined our machine learning objective as the early prediction of conditions. More important than early prediction itself, however, is its impact on patient health outcomes. To approximate this impact on mortality, we compared mortality rates between patients who received early and late diagnoses of ARDS. "Early" and "late" were defined relative to when the algorithm was likely to provide its prediction, i.e. relative to Algotime, the time at which the minimum set of required features became available in the EHR. If a clinical diagnosis of ARDS, using the ARDS-1 definition (ARDS ICD code and SpO2 below 91%), was made before Algotime, this was considered an early diagnosis of ARDS; if made after Algotime, a late diagnosis.

Machine Learning Models
We used a recurrent neural network (RNN) as the main ML model for our research. An RNN is a class of artificial neural networks in which connections between nodes form a directed graph along a temporal sequence. RNNs can use their internal memory to process variable-length sequences of inputs, learn a mapping from inputs over time to an output, and capture temporal dependence in the data. These properties make the RNN well suited to time series data like that used in this study. The model schema of the RNN used in this research is presented in Figure 3. The RNN was implemented with the PyTorch package version 1.4.0 in Python 3.6 (26). The sequence module was a 4-layer gated recurrent unit (GRU) (27) with 128 hidden units. Before the sequence of vectors was fed to the GRU, it passed through a normalization layer:

Norm(x) = a · (x − µ) / σ + b,    (1)

a normalization function that learns the mean µ, standard deviation σ, scaling factor a, and translation factor b used to normalize the sequence of input vectors (age, sex booleans, vitals, labs, and systemic inflammatory response syndrome, SIRS, score) before entering the RNN. A soft attention module assigned a score to each timestep in the sequence; scores are intended to be positively correlated with the importance of the respective timestep. The importance-weighted sum of the sequence's hidden activations is called the context vector. We concatenated the context vector to the final GRU embedding and passed the result to a 2-layer feed-forward neural network for classification. The intermediate layer before the output logits was a 128-dimensional representation of each patient, referred to as the penultimate embedding. Similar to Bahdanau et al. 2014 (28), the attention score was parameterized by a feed-forward neural network of the form

score_i = K · tanh(W_b · PReLU(W_a · [h_l, h_i])),    (2)

where tanh and PReLU denote the hyperbolic tangent and parameterized rectified linear unit non-linearities, respectively; h_l denotes the last hidden activation of the GRU; h_i denotes the hidden activation at timestep i; [,] denotes concatenation of separate vectors into one; and K, W_a, and W_b denote learned parameters of the neural network. The whole GRU-RNN, attention module, and classification module were end-to-end differentiable, enabling optimization from input to output. The attention network enabled higher quality learning: instead of summarizing a time series of vectors uniformly, it assigned each vector a score according to how important that vector was for the prediction, allowing the RNN to focus on specific parts of the input and thereby improving model performance. Each point in the RNN model schema represents a neuron. At each layer, the RNN combined information from the current and previous timesteps to update the activations of the deepest GRU hidden layer. The last and deepest RNN activation was concatenated with the context vector, the importance-weighted average of the deepest-layer activations produced by the attention network. This concatenated vector was passed through 2 fully connected (FC) layers to generate an output (e.g. prediction of ARDS onset). With this RNN schema, the model was trained to predict several target labels simultaneously and to evaluate a loss function based on all targets. We implemented a deep learning method in which a single network was trained to output one logit per label, using a binary cross entropy loss function (29).
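The attention scoring and context-vector computation of equation (2) can be illustrated with a small NumPy sketch. The exact layer composition inside the score network is our reconstruction from the named parameters (K, W_a, W_b), so treat the shapes and ordering as assumptions rather than the paper's implementation:

```python
import numpy as np

def prelu(x, alpha=0.25):
    """Parameterized ReLU with a fixed illustrative slope."""
    return np.where(x > 0, x, alpha * x)

def soft_attention(H, Wa, Wb, K):
    """H: [T, d] hidden activations over T timesteps; h_l = H[-1].
    score_i = K . tanh(Wb @ prelu(Wa @ [h_l, h_i])), softmax over i,
    context = sum_i weight_i * h_i (the importance-weighted average)."""
    h_l = H[-1]
    scores = np.array([
        K @ np.tanh(Wb @ prelu(Wa @ np.concatenate([h_l, h_i])))
        for h_i in H
    ])
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    weights /= weights.sum()
    context = weights @ H
    return context, weights
```

In the full model, this context vector is concatenated to the final GRU embedding before the classification layers.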
Each output logit was independently passed through a sigmoid activation function to give the final multi-label output (30). Early stopping was based on the ARDS-1 validation performance measured as area under the Receiver Operating Characteristic curve (AUROC). To explore the relationship between the diversity of the objective function's targets and final model performance, the 2 lowest-AUROC targets were successively removed from each version of the RNN, such that the RNN was trained using 13, 11, 9, 7, and 5 targets.
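The multi-label head and the target-pruning schedule can be sketched as follows, a minimal NumPy illustration assuming independent sigmoids with a mean binary cross entropy; the helper names are our own:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def multilabel_bce(logits, targets):
    """One logit per label; each passes independently through a sigmoid,
    and the loss is the mean binary cross entropy over all labels."""
    p = sigmoid(logits)
    eps = 1e-12  # guard against log(0)
    return -np.mean(targets * np.log(p + eps)
                    + (1 - targets) * np.log(1 - p + eps))

def prune_targets(targets, aurocs, drop=2):
    """Drop the `drop` lowest-AUROC targets to form the next label set
    (13 -> 11 -> 9 -> 7 -> 5 in the experiments above)."""
    order = np.argsort(aurocs)
    keep = sorted(order[drop:])
    return [targets[i] for i in keep]
```

For example, uninformative zero logits against all-positive targets give a loss of ln 2 per label, the usual sanity check for this loss.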
Tree-based models frequently outperform deep learning in many clinical applications (31). To test whether this holds here, XGBoost (XGB) models for each of the target labels were trained for comparison, using the XGBoost v0.81 package in Python 3.6 (27,28) and the same feature matrix as the RNN model. The XGB models were trained in a one-versus-all fashion for each target.

Clustering
To explore the representations used by the model and to reveal distinct phenotypes among ARDS patients, we collected the 64-dimensional activations produced by the first fully connected layer (FC1) as a compressed representation, or embedding, for each patient. To visually display these embeddings, we used Uniform Manifold Approximation and Projection for Dimension Reduction (UMAP) to reduce the embedding to two dimensions (34). We then used k-means clustering to assign each ARDS patient to a cluster. Clusters of embeddings are not inherently interpretable; therefore, to confer clinical interpretability on these clusters, a new XGB model was trained using the original features as inputs and the cluster labels as outputs. The XGB model's SHAP (SHapley Additive exPlanations) plots (35) were used to interpret the clinical properties of each cluster.
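The clustering step can be illustrated with a minimal Lloyd's-algorithm k-means in NumPy. The paper does not specify its k-means implementation, and whether clustering ran on the 64-dimensional embedding or on the 2-D UMAP projection is not stated, so this is a sketch of the algorithm only:

```python
import numpy as np

def kmeans(X, k, iters=50, seed=0):
    """Minimal Lloyd's-algorithm k-means: X is [n_patients, dim]
    (e.g. the FC1 embeddings); returns hard cluster labels and centers."""
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(iters):
        # Squared Euclidean distance from every point to every center.
        d = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
        labels = d.argmin(1)
        for j in range(k):
            if (labels == j).any():
                centers[j] = X[labels == j].mean(0)
    return labels, centers
```

In practice a library implementation (e.g. scikit-learn's KMeans with k = 3, the cluster count observed in the UMAP visualization) would be used instead.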

Statistics
To compare different algorithms and training objectives, we computed the 95% confidence interval around the AUROC using the bootstrapping method (36,37). These intervals are with respect to the test set (n=951).
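The percentile-bootstrap AUROC confidence interval can be sketched as follows: a rank-based AUROC plus patient-level resampling with replacement. The resample count and the percentile method are assumptions, as the paper does not detail them:

```python
import numpy as np

def auroc(y_true, scores):
    """Rank-based AUROC: probability that a random positive outranks a
    random negative, with ties counted as half."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

def bootstrap_ci(y_true, scores, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap CI: resample patients with replacement,
    skipping degenerate resamples that contain only one class."""
    rng = np.random.default_rng(seed)
    n = len(y_true)
    stats = []
    while len(stats) < n_boot:
        idx = rng.integers(0, n, size=n)
        if y_true[idx].min() == y_true[idx].max():
            continue
        stats.append(auroc(y_true[idx], scores[idx]))
    return np.quantile(stats, [alpha / 2, 1 - alpha / 2])
```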

Ethics
All patient data was de-identified in compliance with the Health Insurance Portability and Accountability Act (HIPAA). This study was considered to be of minimal risk for human subjects as data collection was passive and did not pose a threat to the subjects involved. The project was approved with a waiver of informed consent (protocol study number 20-DASC-122) by an independent institutional review board, Pearl IRB (Indianapolis, Indiana; www.pearlirb.com).

Comparison of RNN model with XGB
The XGB and RNN models were compared across all 13 outputs.

Benefit of Diverse Training Objectives
Intermediate numbers of output targets between 1 and 13 were also used to re-train the RNN. Figure 4 shows the maximum AUROC for different subsets of the 13 outcomes used as training targets. For most targets there is a general trend of improving AUROC as targets are added; for ARDS-2 and ARDS-5 the improvement is monotonic in the number of targets. This demonstrates that there is some underlying dependency between some of the labels.

Figure 4. Model Performance Varies with the Number of Outcomes Predicted During Training.
Validation AUROC plotted against the number of targets in the RNN output (e.g. RNN9 refers to an RNN with 9 outputs). From right to left, the 2 worst-performing targets in terms of AUROC are removed to train the next RNN with a smaller number of targets.

Training on Diverse Objectives Converges in a Comparable Number of Epochs
Learning quality and efficiency of single- versus multiple-outcome models were evaluated in terms of the rate of improvement of validation AUROC per stochastic gradient descent training epoch. We compared the rate of learning between RNNs trained with single targets and RNNs trained with multiple targets to demonstrate that training on diverse learning objectives does not empirically require longer training than single learning objectives. The rate of learning was measured as the validation AUROC at each epoch. Figure 5 plots the AUROC of three separately and randomly initialized training runs over 15 epochs for ARDS-1 and ARDS-2. For these two outcomes, the number of epochs needed to reach the maximum validation AUROC is comparable between single- and multiple-target models.

Figure 6 shows the result of applying the K-means clustering algorithm to find clusters of patients, as described in the Machine Learning Models section. The K-means algorithm was set to identify 3 clusters, the number of clusters observed in the UMAP visualization of the FC1-layer output of the neural network. In Figure 7, we see that two clusters, A and B, have similarly high mortality rates but different clinical presentations. Cluster B was characterized by features that clearly signal suspicion for ARDS, such as low oxygen saturation (especially below 91%), neutrophilia, fever, and tachypnea. Cluster A, however, had a more subtle presentation: oxygen saturation that was moderately low but above 91%, and leukocytosis without fever. For these patients the model predicted that, although desaturation below 91% had not yet occurred, it was likely to occur in the future.

Benefit Estimation
From our benefit estimation case study, we found that the mortality rate for patients diagnosed early with ARDS was 14/266 (5.26%), whereas for those diagnosed late with ARDS it was 116/995 (11.66%). The Fisher exact test p-value was 0.002 for the 6.4% absolute mortality benefit of early diagnosis. For reference, the baseline mortality of patients without ARDS was 656/39,442 (1.66%) and of patients with ARDS was 130/1,261 (10.31%).
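The 2x2 comparison above can be checked with a one-sided Fisher exact (hypergeometric tail) computation in pure Python. Note the paper's quoted 0.002 is presumably the two-sided p-value; this one-sided sketch is a lower bound on it and serves only as a consistency check:

```python
from math import comb

def fisher_one_sided(a, b, c, d):
    """One-sided Fisher exact test for the 2x2 table [[a, b], [c, d]]
    (rows: early/late diagnosis; columns: died/survived):
    P(deaths in the early group <= a) under the hypergeometric null
    with the row and column totals fixed."""
    n = a + b + c + d
    row1 = a + b        # size of the early-diagnosis group
    col1 = a + c        # total deaths
    denom = comb(n, row1)
    return sum(comb(col1, k) * comb(n - col1, row1 - k)
               for k in range(a + 1)) / denom

# Early: 14 deaths / 266 patients; late: 116 deaths / 995 patients.
p = fisher_one_sided(14, 252, 116, 879)
```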

Conclusion
In this study, we described the development of a deep learning model for predicting multiple outcomes simultaneously using the same set of input features. We showed that the RNN13 model, trained to predict 13 outcomes simultaneously, generalizes better on most outcomes than XGB models trained to predict individual outputs, and that this improvement grows with the number of targets predicted by the RNN. Additionally, RNN13 generalized better in ARDS prediction and, in some cases, reached its highest validation AUROC in fewer training epochs. This reinforces our conclusion that training the RNN model on a larger set of outcomes improves generalization. We hypothesize that diverse training objectives generalize better in part due to parameter sharing, which has a regularizing effect, and information sharing across outcomes, which yields richer representations (38).
We used an RNN in this research because of its ability to use its internal memory to process variable-length sequences of inputs, learn temporal dependence from the data, and share representations across an arbitrary number of outputs. We used a generic RNN with 4 GRU layers, an attention module, and 2 fully connected layers for all numbers of outputs. We experimented with various RNN architectures, varying parameters such as the number of layers and hidden units; from this light grid search, we found that the architecture used in this paper performed best for our use case. Additionally, the attention module appears to be an important part of the architecture: without it, the performance of the RNN dropped significantly.
To compare RNN with other algorithms, we used XGB because of its ability to handle missing or null values and its current dominance in industrial applications. Because EHR data often has a high level of missing values due to variability in data acquisition and recording habits in the live clinical environment, this quality is appealing. We trained multiple XGB models separately on the same input to classify different outcomes independently. We performed a grid search for hyperparameter optimization, tuning parameters such as tree depth and learning rate.
We also demonstrated an application of cluster analysis to probe deep learning models for clinical insights. Our analysis of the total ARDS population uncovered 3 distinct populations, two of which have similarly high mortality rates but different clinical presentations. One population of ARDS patients had apparent signs and symptoms of ARDS, while the other had much less obvious signs yet a similar future trajectory into respiratory decompensation and death. Recent studies have reported similar results in the COVID-19 population, in which two distinct phenotypes of ARDS were found with similar respiratory dynamics but a 2-fold difference in the odds of 28-day mortality (39). With the methods outlined in this study, phenotype discovery would be an additional benefit automatically applied to an arbitrarily large number of predicted outcomes.
To connect our machine learning findings with real world clinical effects, we compared the mortality rates between patients diagnosed earlier with ARDS and patients diagnosed later with ARDS relative to the algorithm's prediction time. Our estimation showed that the mortality rate in the population diagnosed early with ARDS was almost half the rate in the population who were diagnosed late with ARDS.

Limitations and future research
This study has several limitations. In many hospital systems, radiology images and radiology reports are kept in a software system separate from the EHR. Ideally, we would prefer to confirm ARDS ICD codes by verifying the presence of bilateral lung infiltrates on chest imaging; our inputs only included demographics, vital signs, and lab information. Future work should therefore incorporate imaging data alongside the EHR. Our dataset spans the ED, inpatient, and ICU settings and prescribes a single early time point for prediction. This could be one factor in the low AUROC for sepsis predictions, which prior studies have shown to be reliably predictable in the ICU setting (3,8); this discrepancy warrants future investigation. Additionally, we did not have reliable data on the race and ethnicity of the patient population. Future studies should explore the correlation between "early" or "late" diagnoses and the onset of standard ARDS criteria; this may provide evidence for why ARDS diagnosed early might carry improved mortality, for example because it is caught earlier in the disease course. Furthermore, this is a retrospective study; we therefore cannot determine the performance of our algorithm in a prospective clinical setting. Prospective validation is essential to determine how clinicians will respond to predictions of various outcomes, and whether our predictions can impact patient outcomes or resource allocation. Our work here is meant to serve as a reference for future research directions in establishing the most beneficial role for MLAs in the healthcare ecosystem and expanding the capabilities of machine learning in healthcare.

Clusters Represent Distinct ARDS Phenotypes. SHAP plots for the top 15 features for 3 different XGB models trained using the input features to predict each of the 3 clusters of Fig. 7. Unlike the other features, names ending in "future" are not inputs to the models, but rather components of the target outcomes for these patients. For example, "SpO2 < 91 future" indicates that these patients will desaturate in the future after Algotime. (A) considers cluster A as the positive class and the other two clusters as the negative class; (B) and (C) do the same for clusters B and C, respectively.