In this paper, we propose ResTR, whose overall framework is shown in Figure 1. As the figure illustrates, ResTR consists of two different branches that extract semantic features and structural features, respectively. Furthermore, we design a joint loss function to optimize the end-to-end model.
2.1 CNN Branch
As a CNN model widely used in image classification tasks, ResNet is adopted to extract local features. We use the core of the ResNet architecture, the residual structure, and discard the fully connected layer. The residual structure alleviates the vanishing-gradient and performance-degradation problems while maintaining network complexity and depth, and achieves very strong classification results. Unlike other traditional CNNs, ResNet performs downsampling with stride-2 convolutions and replaces the fully connected layer with a global average pooling layer. Feature extraction is performed by stacking 3×3 convolutions, and training is accelerated using Batch Normalization (dropout is therefore not used). Figure 3 shows the network structure of ResNet, which is conventionally divided into 5 stages.
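As a concrete illustration, the following is a minimal sketch of how such a CNN branch can be built in PyTorch by stripping the classifier from a torchvision ResNet; the choice of resnet18 and the 224 × 224 input size are assumptions for illustration, not the paper's exact configuration.

```python
# A minimal sketch of the CNN branch: a torchvision ResNet with its final
# classifier removed so it returns local feature maps. resnet18 and the
# 224x224 input size are assumed values, not the paper's exact settings.
import torch
import torch.nn as nn
from torchvision.models import resnet18

class CNNBranch(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = resnet18(weights=None)
        # Keep everything up to the last residual stage; drop the global
        # average pooling and fully connected layers so the branch outputs
        # a spatial feature map instead of class logits.
        self.features = nn.Sequential(*list(backbone.children())[:-2])

    def forward(self, x):
        # x: (B, 3, H, W) -> (B, 512, H/32, W/32) for resnet18
        return self.features(x)

if __name__ == "__main__":
    feats = CNNBranch()(torch.randn(2, 3, 224, 224))
    print(feats.shape)  # torch.Size([2, 512, 7, 7])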
2.2 Transformer Branch
We use the Transformer to extract global features, thanks to its ability to model long-range dependencies between input sequence elements. Unlike the CNN branch, the input image first needs to be divided into separate patches of size P × P [ViT]; an image thus yields 16 patches. The patches are flattened into 1D vectors and embedded into D dimensions by a linear layer. A standard learnable 1D position embedding is added to the patch embeddings to encode the position information of the image sequence. The resulting vector sequence is then fed into the Transformer encoder. More specifically, the encoder consists of alternating Multi-head Self-Attention (MSA) and MLP blocks; LayerNorm (LN) is applied before each block and a residual connection is applied after each block.
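A minimal sketch of this branch is given below: P × P patch embedding with a learnable 1D position embedding, followed by pre-norm encoder blocks (LN, then MSA or MLP, then a residual connection). The patch size of 56 is an assumption chosen so a 224 × 224 image yields the 16 patches mentioned above; the embedding dimension and depth are likewise assumed.

```python
# A minimal sketch of the Transformer branch described above. Patch size,
# embedding dimension and depth are assumed illustration values.
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, dim=256, heads=8, mlp_ratio=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.msa = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim))

    def forward(self, x):
        h = self.ln1(x)
        x = x + self.msa(h, h, h)[0]   # LN -> MSA -> residual connection
        x = x + self.mlp(self.ln2(x))  # LN -> MLP -> residual connection
        return x

class TransformerBranch(nn.Module):
    def __init__(self, img_size=224, patch=56, dim=256, depth=4):
        super().__init__()
        n = (img_size // patch) ** 2                      # 16 patches for 224/56
        self.proj = nn.Linear(3 * patch * patch, dim)     # linear patch embedding
        self.pos = nn.Parameter(torch.zeros(1, n, dim))   # learnable 1D pos. emb.
        self.blocks = nn.Sequential(*[EncoderBlock(dim) for _ in range(depth)])
        self.patch = patch

    def forward(self, x):
        b, c, hh, ww = x.shape
        p = self.patch
        # split into non-overlapping PxP patches and flatten each to 1D
        x = x.unfold(2, p, p).unfold(3, p, p)             # (B, C, H/P, W/P, P, P)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * p * p)
        return self.blocks(self.proj(x) + self.pos)       # (B, N, dim)
```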
2.3 Self-Attention Overview
The core of the Transformer structure is the self-attention (SA) mechanism. In the image classification task of this paper, SA can be understood as computing the correlation between every pair of pixel positions. The computation proceeds as follows: given an input feature map with height H, width W and C channels, the output of a self-attention layer at location o is computed as

$$y_o = \sum_{p \in \mathcal{N}} \operatorname{softmax}_p\!\left(q_o^{\top} k_p\right) v_p, \tag{1}$$

where the queries $q = W_Q x$, keys $k = W_K x$ and values $v = W_V x$ are linear projections of the input, and $\mathcal{N}$ denotes the set of all $H \times W$ spatial locations.
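To make the cost of Eq. 1 concrete, the following is a minimal sketch of global 2D self-attention over a feature map; note the (HW × HW) affinity matrix, which is what motivates the axial decomposition below. The 1×1 convolutions producing q, k, v and the channel sizes are assumptions.

```python
# A minimal sketch of Eq. 1: every spatial position attends to all H*W
# positions, so the affinity matrix has shape (HW, HW). The 1x1 convolutions
# and qk_ch=32 are assumed illustration choices.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalSelfAttention2d(nn.Module):
    def __init__(self, in_ch, qk_ch=32):
        super().__init__()
        self.q = nn.Conv2d(in_ch, qk_ch, 1)   # W_Q
        self.k = nn.Conv2d(in_ch, qk_ch, 1)   # W_K
        self.v = nn.Conv2d(in_ch, in_ch, 1)   # W_V

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.q(x).flatten(2).transpose(1, 2)      # (B, HW, qk_ch)
        k = self.k(x).flatten(2)                      # (B, qk_ch, HW)
        v = self.v(x).flatten(2).transpose(1, 2)      # (B, HW, C)
        attn = F.softmax(q @ k, dim=-1)               # (B, HW, HW) affinities
        return (attn @ v).transpose(1, 2).reshape(b, c, h, w)
```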
In the actual process of computing such affinities over the whole H × W plane, the calculation is very expensive. Inspired by axial attention [Axial-DeepLab], we decompose self-attention into two modules: the first performs self-attention along the height axis of the feature map and the second along the width axis. Axial attention effectively captures non-local contextual information while being computationally more efficient; it can also effectively embed location information and capture long-range dependencies between feature maps. Following [Axial-DeepLab], the updated self-attention mechanism along the height axis can be written as

$$y_{ij} = \sum_{h=1}^{H} \operatorname{softmax}\!\left(q_{ij}^{\top} k_{hj} + q_{ij}^{\top} r^{q}_{hj} + k_{hj}^{\top} r^{k}_{hj}\right)\left(v_{hj} + r^{v}_{hj}\right), \tag{2}$$

where $r^{q}$, $r^{k}$ and $r^{v}$ are learnable positional encodings for the queries, keys and values, respectively. Eq. 2 describes axial attention applied along the height axis of the tensor; a similar formulation is then applied along the width axis. Figure 2 illustrates the calculation process of axial attention along the two directions.
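The following is a minimal sketch of the height-axis module of Eq. 2: each column of the feature map is treated as a sequence of length H, so the affinity matrix shrinks from (HW × HW) to (H × H) per column. For brevity this sketch omits the learnable positional terms r^q, r^k, r^v; channel sizes are assumptions. The width-axis module is obtained analogously by swapping the roles of H and W.

```python
# A minimal sketch of axial attention along the height axis (Eq. 2 without
# the positional terms). Width is folded into the batch dimension so each
# column attends only within itself.
import torch
import torch.nn as nn
import torch.nn.functional as F

class AxialAttentionHeight(nn.Module):
    def __init__(self, in_ch, qk_ch=32):
        super().__init__()
        self.q = nn.Conv2d(in_ch, qk_ch, 1)
        self.k = nn.Conv2d(in_ch, qk_ch, 1)
        self.v = nn.Conv2d(in_ch, in_ch, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        # move width into the batch dimension: every column is one sequence
        q = self.q(x).permute(0, 3, 2, 1).reshape(b * w, h, -1)  # (B*W, H, qk)
        k = self.k(x).permute(0, 3, 1, 2).reshape(b * w, -1, h)  # (B*W, qk, H)
        v = self.v(x).permute(0, 3, 2, 1).reshape(b * w, h, c)   # (B*W, H, C)
        attn = F.softmax(q @ k, dim=-1)                          # (B*W, H, H)
        y = (attn @ v).reshape(b, w, h, c).permute(0, 3, 2, 1)   # (B, C, H, W)
        return y
```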
Dataset Details
The main fundus manifestations of RVO are tortuous dilatation of the affected veins and flame-shaped haemorrhages along the retinal veins [7]. Each eye of a patient is treated as an independent classification sample. MultiColor (MC) imaging is one of the new technologies based on the confocal scanning laser ophthalmoscope (cSLO), which can acquire multiple modal images at different wavelengths. We collected the MC imaging examination results of 29 patients as an internal dataset, with a total of 220 images (some cases have missing images). The number was increased to 1320 by data augmentation strategies, including random cropping, rotation, horizontal flipping and color jittering; a minimal sketch of this pipeline is given below. Ophthalmologists manually labelled each sample as RVO or non-RVO, and these labels are treated as the ground truth for validating our algorithm.
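The sketch below expresses the augmentation strategy with torchvision transforms; the crop size, rotation range and jitter strengths are assumed values, not the paper's exact settings.

```python
# A minimal sketch of the data augmentation described above (random cropping,
# rotation, horizontal flipping, color jittering). All magnitudes are assumed.
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # random cropping
    transforms.RandomRotation(degrees=15),                # rotation
    transforms.RandomHorizontalFlip(p=0.5),               # horizontal flipping
    transforms.ColorJitter(brightness=0.2, contrast=0.2,
                           saturation=0.2),               # color jittering
    transforms.ToTensor(),
])
```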
Experimental Setting
We implement the proposed framework and perform all experiments using the PyTorch library, on a GPU cluster with four NVIDIA GeForce RTX 3090 GPUs, each with 24 GB of memory. Optimization runs for 50 epochs with a learning rate of 0.0001 and a batch size of 12. All methods are optimized with the Adam optimizer, and the cross-entropy function is selected as the classification loss.
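For reference, the training setup above corresponds to the following minimal sketch, where `model` and `train_set` are placeholders standing in for the ResTR network and the MC image dataset, which are not defined here.

```python
# A minimal sketch of the training configuration: Adam at lr=1e-4, batch
# size 12, 50 epochs, cross-entropy loss. `model` and `train_set` are
# hypothetical placeholders for the ResTR model and the MC dataset.
import torch
from torch.utils.data import DataLoader

loader = DataLoader(train_set, batch_size=12, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(50):
    for images, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
```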