In neural MT, lexical features are fed to the network as lexical representations (i.e., word embeddings) at the first layer of the encoder and are refined as they propagate through the deep network of hidden layers. In this post we’ll try to understand how the lexical representation changes as it travels deeper into the network, and whether this affects translation quality.
Recently, several studies have investigated the nature of the language features encoded within individual layers of neural translation models. Belinkov et al. (2018) reported that in recurrent architectures, different layers prioritise different types of information: lower layers tend to represent morphological and syntactic information, whereas semantic features are concentrated towards the top of the layer stack. Ideally, the information encoded across all layers would be made available to the decoder, whereas in practice only the final layer is used.
Along the same lines, Emelin et al. (2019) studied the transformer architecture. In the transformer model, information proceeds in a strictly sequential manner: each layer attends over the output of the immediately preceding layer, complemented by a shallow residual connection. For input features to reach the uppermost layers, the encoder must preserve them across every intermediate layer until they are needed. Because capacity is spent retaining this lexical content, the model cannot use its full representational capacity to learn new information from other sources, such as the surrounding sentence context. They refer to this limitation as the representation bottleneck.
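To make the "strictly sequential" flow concrete, here is a minimal PyTorch sketch of how information moves through a stack of simplified encoder sub-layers. The `ResidualBlock` class and its dimensions are illustrative choices of mine, not the paper's code; the point is only that each layer sees just the layer below it, so lexical features must survive every hop to reach the top.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Simplified encoder sub-layer: each layer receives only the
    output of the layer immediately below, plus a residual connection."""
    def __init__(self, d_model: int):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, h_prev: torch.Tensor) -> torch.Tensor:
        # Residual connection: the only path lexical content has upwards.
        return self.norm(h_prev + self.ff(h_prev))

d_model, num_layers = 64, 6
layers = nn.ModuleList(ResidualBlock(d_model) for _ in range(num_layers))

h = torch.randn(2, 10, d_model)  # embeddings: (batch, seq_len, d_model)
for layer in layers:             # strictly sequential propagation:
    h = layer(h)                 # layer l only ever sees h from layer l-1
```

Note that the embeddings themselves are never consulted again after the first layer; everything the top layers know about the input words has to be carried through the hidden states.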
Lexical shortcuts (Proposed Method)
To alleviate this representation bottleneck, Emelin et al. (2019) proposed extending the standard transformer architecture with lexical shortcuts. In the proposed architecture, the embedding layer is directly connected to each subsequent self-attention sub-layer in both the encoder and the decoder. These shortcuts allow the model to access the relevant lexical information at any point, instead of propagating it upwards from the embedding layer through the hidden states. In an alternative formulation, referred to as feature fusion, they concatenate the output of the immediately preceding layer H_{l-1} with the embeddings E before the initial linear projection that produces the input for the following layer. For the mathematical details of the proposed method, refer to Emelin et al. (2019).
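The feature-fusion idea can be sketched in a few lines of PyTorch. This is a hedged illustration under my own simplifications (the `FeatureFusionLayer` name, single attention sub-layer, and hyperparameters are mine, not the paper's): the previous layer output H_{l-1} is concatenated with the embeddings E and projected back to the model dimension, so every layer has direct access to lexical features.

```python
import torch
import torch.nn as nn

class FeatureFusionLayer(nn.Module):
    """Illustrative feature-fusion sub-layer: concatenate the previous
    layer's output with the embeddings, project back to d_model, then
    run self-attention. A sketch, not the authors' implementation."""
    def __init__(self, d_model: int, num_heads: int = 4):
        super().__init__()
        self.fuse = nn.Linear(2 * d_model, d_model)  # fusion projection
        self.attn = nn.MultiheadAttention(d_model, num_heads,
                                          batch_first=True)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, h_prev: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # Lexical shortcut: emb reaches this layer directly,
        # rather than being carried through the hidden states.
        fused = self.fuse(torch.cat([h_prev, emb], dim=-1))
        out, _ = self.attn(fused, fused, fused)
        return self.norm(h_prev + out)

d_model = 64
emb = torch.randn(2, 10, d_model)  # embedding layer output E
h = emb
for _ in range(6):                 # every layer sees E directly
    h = FeatureFusionLayer(d_model)(h, emb)
```

Because each layer is handed `emb` explicitly, the hidden states no longer need to reserve capacity for preserving lexical content, which is exactly the bottleneck the shortcuts are meant to widen.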
They evaluated the resulting model’s performance on multiple language pairs and corpus sizes, showing a consistent improvement in translation quality over the transformer baseline. Specifically, they trained models on five WMT (Workshop on Machine Translation) datasets: German→English (DE→EN), English→German (EN→DE), English→Russian (EN→RU), English→Czech (EN→CS), and English→Finnish (EN→FI). The transformer model with lexical shortcuts outperforms the baseline by 0.5 BLEU on average. With feature fusion, they reported even stronger improvements: +1.4 BLEU for EN→DE and +0.8 BLEU for the other four translation directions. Further, they reported that the gains from adding lexical shortcuts are substantially smaller for transformer-BIG than for transformer-BASE. One potential explanation for this drop is that the wider model already has the capacity to learn the additional information.
Adding lexical shortcuts is motivated by the hypothesis that transformer models retain lexical information within their individual layers, which limits their capacity for learning and representing other types of relevant information. Direct connections to the embedding layer alleviate this by giving the model access to lexical features at each processing step, freeing up capacity for incorporating other relevant information. It is also worth noting that transformer-BASE equipped with lexical shortcuts achieves quality scores comparable to the standard transformer-BIG, while being about 2.3x faster to train and decode.