Test-Time Domain Adaptation by Learning Domain-Aware Batch Normalization

Yanan Wu¹*, Zhixiang Chi²*, Yang Wang³, Konstantinos N. Plataniotis², Songhe Feng¹
¹Beijing Jiaotong University, ²University of Toronto, ³Concordia University
AAAI 2024 (Oral)

*Indicates Equal Contribution

Abstract

Test-time domain adaptation aims to adapt a model trained on source domains to unseen target domains using a few unlabeled images. Emerging research has shown that label and domain information are embedded separately in the weight matrices and the batch normalization (BN) layers, respectively. Previous works typically update the whole network without explicitly decoupling label and domain knowledge, which leads to knowledge interference and defective distribution adaptation. In this work, we propose to reduce such learning interference and enhance domain knowledge learning by manipulating only the BN layers. However, the normalization step in BN is intrinsically unstable when the statistics are re-estimated from a few samples. We find that ambiguities can be greatly reduced by updating only the two affine parameters in BN while keeping the source domain statistics. To further enhance domain knowledge extraction from unlabeled data, we construct an auxiliary branch with label-independent self-supervised learning (SSL) to provide supervision. Moreover, we propose a bi-level optimization based on meta-learning that aligns the learning objectives of the auxiliary and main branches, so that adapting to a domain through the auxiliary branch benefits the main task in subsequent inference. Our method adds no computational cost at inference, as the auxiliary branch can be discarded entirely after adaptation. Extensive experiments show that our method outperforms prior works on five real-world domain shift datasets from WILDS. Our method can also be integrated with methods that use label-dependent optimization to further push the performance boundary.


Problem Setting

[Figure: problem setting]


Unsupervised Domain Adaptation (UDA) uses unlabeled data from a target domain during training to transfer knowledge from the labeled source domains. However, it is less applicable in real-world scenarios because it requires repeated large-scale training for every target domain.

Domain Generalization (DG) assumes no prior knowledge of the target domains and expects a model trained on source domains to perform well across all target domains. While DG is more practical, it is often suboptimal because it does not adapt to the domain-specific knowledge of each target domain.

In this work, we focus on Test-time Domain Adaptation (TTDA), also referred to as Few-shot TTDA, which combines elements of UDA and DG. It follows the source-free setting of DG but adds a learning phase at test time for each target domain: when an unseen target domain is encountered at test time, a few unlabeled images are sampled from it to update the model towards that domain. The adapted model is then used to evaluate the data from that domain (as shown in the figure above).
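The TTDA protocol can be summarized in a short sketch. This is a hypothetical illustration, not the authors' evaluation code: `adapt` and `predict` are placeholder callables, the model is a plain dict of parameters, and each domain is a list of inputs with labels that are used only for scoring. The points it shows are that every target domain starts from the same source-trained model and that adaptation sees only a few unlabeled samples.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

Params = Dict[str, float]
Domain = Tuple[List[float], List[int]]  # (inputs, labels); labels are used only for scoring


def evaluate_ttda(
    source_params: Params,
    domains: Dict[str, Domain],
    adapt: Callable[[Params, Sequence[float]], Params],
    predict: Callable[[Params, float], int],
    k: int,
    seed: int = 0,
) -> Dict[str, float]:
    """Few-shot TTDA: adapt per domain on k unlabeled inputs, then evaluate on that domain."""
    rng = random.Random(seed)
    accuracy = {}
    for name, (inputs, labels) in domains.items():
        support = rng.sample(inputs, k)  # a few unlabeled samples; labels are never passed to adapt
        # Copy the parameters so every domain is adapted from the same source model.
        adapted = adapt(dict(source_params), support)
        preds = [predict(adapted, x) for x in inputs]
        accuracy[name] = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    return accuracy


# Hypothetical toy "model": adaptation re-centers a decision threshold on the support mean.
def adapt_threshold(params: Params, support: Sequence[float]) -> Params:
    params["threshold"] = sum(support) / len(support)
    return params


def predict_threshold(params: Params, x: float) -> int:
    return int(x > params["threshold"])


source = {"threshold": 0.0}
domains = {"A": ([10.0, 12.0], [0, 1]), "B": ([-1.0, 1.0], [0, 1])}
result = evaluate_ttda(source, domains, adapt_threshold, predict_threshold, k=2)
print(result)
```

Without adaptation, the source threshold of 0.0 would label both inputs of domain "A" as class 1; after per-domain adaptation both toy domains are classified correctly, and `source` itself is left unchanged.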


Motivation



Our work is partly inspired by the observation that the weight matrices tend to encapsulate label information, while domain-specific knowledge is embedded in the BN layers. We therefore propose to manipulate only the BN layers to improve the learning and transfer of domain-specific knowledge. A BN layer normalizes its input features and then re-scales and shifts them with two affine parameters. However, normalization statistics re-estimated for the target domain under TTDA can be unstable, since only a small batch of target examples is available. Instead, we adapt only the two affine parameters and directly reuse the normalization statistics learned from the source domains during training. The affine parameters are updated on the unlabeled target data with a self-supervised learning objective.
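The instability of re-estimated statistics can be seen in a deliberately extreme toy case (our illustration, not an experiment from the paper): a single BN channel and a target batch of only two samples. With two samples, re-estimated normalization maps any pair of distinct values to approximately -1 and +1, so the output no longer reflects the target inputs. Keeping the source statistics (hypothetical values below) avoids this, leaving the affine parameters `gamma` and `beta` as the only quantities to adapt.

```python
import math
from typing import List, Sequence, Tuple


def batch_stats(x: Sequence[float]) -> Tuple[float, float]:
    """Mean and biased variance of one channel over the batch (what BN normalizes with)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return mean, var


def bn(x: Sequence[float], mean: float, var: float,
       gamma: float, beta: float, eps: float = 1e-5) -> List[float]:
    """Normalize with the given statistics, then re-scale by gamma and shift by beta."""
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]


# Hypothetical source statistics stored in the BN layer after training.
src_mean, src_var = 0.0, 1.0
gamma, beta = 1.0, 0.0  # the affine parameters; the only ones adapted at test time

batch_a = [3.0, 3.5]     # two target samples
batch_b = [-10.0, 40.0]  # two very different target samples

# Re-estimating statistics from two samples maps both batches to roughly [-1, 1].
reest_a = bn(batch_a, *batch_stats(batch_a), gamma, beta)
reest_b = bn(batch_b, *batch_stats(batch_b), gamma, beta)

# Keeping the source statistics preserves the difference between the batches.
src_a = bn(batch_a, src_mean, src_var, gamma, beta)
src_b = bn(batch_b, src_mean, src_var, gamma, beta)

print([round(v, 3) for v in reest_a], [round(v, 3) for v in reest_b])
print([round(v, 3) for v in src_a], [round(v, 3) for v in src_b])
```

Real target batches are larger than two samples, so the collapse is less severe, but the same effect (statistics dominated by the particular samples drawn) is what makes re-estimation from a few images unreliable.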


Method Overview

[Figure: method overview with (a) joint training, (b) meta-auxiliary training, and (c) test-time adaptation]


To learn an initialization of the affine parameters that is well suited to adapting to domain-specific information, we employ a two-stage learning process. In the joint training stage (a), we mix the data from all source domains and train the entire network jointly to learn both label knowledge and normalization statistics. In the meta-auxiliary training stage (b), we first obtain adapted parameters from the auxiliary loss in the inner loop. Then, in the outer loop, the meta-model is updated using the main task loss computed with the adapted parameters. At test time (c), we simply apply the same adaptation step to update the model for an unseen target domain.
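The two nested loops of stage (b) can be illustrated with a generic MAML-style update. This is a minimal sketch under stated assumptions, not the authors' implementation: a plain list stands in for the BN affine parameters γ and β, `aux_toy` and `main_toy` are hypothetical quadratic losses for one source domain, and central finite differences stand in for automatic differentiation so that only the standard library is needed. Because the outer loss is evaluated after the inner adaptation, its gradient flows through the inner update, which is what makes the optimization bi-level.

```python
from typing import Callable, List, Sequence, Tuple

Params = List[float]  # stand-in for the BN affine parameters (gamma, beta)
Loss = Callable[[Params], float]


def num_grad(loss: Loss, p: Params, h: float = 1e-5) -> Params:
    """Central finite-difference gradient; a stand-in for autodiff in this sketch."""
    grad = []
    for i in range(len(p)):
        up, down = list(p), list(p)
        up[i] += h
        down[i] -= h
        grad.append((loss(up) - loss(down)) / (2 * h))
    return grad


def sgd(p: Params, grad: Params, lr: float) -> Params:
    return [pi - lr * gi for pi, gi in zip(p, grad)]


def inner_adapt(p: Params, aux_loss: Loss, inner_lr: float, steps: int = 1) -> Params:
    """Inner loop (and the test-time step): adapt using only the label-free auxiliary loss."""
    for _ in range(steps):
        p = sgd(p, num_grad(aux_loss, p), inner_lr)
    return p


def meta_step(p: Params, tasks: Sequence[Tuple[Loss, Loss]],
              inner_lr: float, outer_lr: float) -> Params:
    """Outer loop: update p so that the main loss *after* inner adaptation decreases."""
    def outer_loss(q: Params) -> float:
        return sum(main(inner_adapt(q, aux, inner_lr)) for aux, main in tasks) / len(tasks)
    return sgd(p, num_grad(outer_loss, p), outer_lr)


# Hypothetical toy losses for one source domain.
def aux_toy(p: Params) -> float:
    return (p[0] - 1.0) ** 2 + (p[1] - 1.0) ** 2


def main_toy(p: Params) -> float:
    return (p[0] - 2.0) ** 2 + (p[1] + 1.0) ** 2


meta_params = [0.0, 0.0]
new_params = meta_step(meta_params, [(aux_toy, main_toy)], inner_lr=0.1, outer_lr=0.1)
before = main_toy(inner_adapt(meta_params, aux_toy, 0.1))
after = main_toy(inner_adapt(new_params, aux_toy, 0.1))
print(new_params, before, after)
```

At test time, only `inner_adapt` would be run on a new domain's unlabeled samples; the main loss needs labels and is used only during meta-training on the source domains.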


Experimental Results


Comparison with state-of-the-art methods on the WILDS benchmark under the out-of-distribution setting.


Comparison with state-of-the-art methods on the DomainNet benchmark under the leave-one-out setting.


Verification of domain knowledge learning. "No adapt" means the meta-learned \( \gamma, \beta \) are used for all domains without adaptation. "Not matched" means each target domain uses the adapted \( \tilde{\gamma}, \tilde{\beta} \) of a randomly chosen other domain instead of its own. "Matched" means each target domain uses its own adapted \( \tilde{\gamma}, \tilde{\beta} \).
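The three settings in this comparison differ only in which affine parameters each target domain is evaluated with. A minimal sketch of that assignment follows; the domain names and parameter values are hypothetical, each BN layer is simplified to a single scalar pair, and `random.Random.choice` picks another domain's parameters for the "Not matched" case.

```python
import random
from typing import Dict, Tuple

Affine = Tuple[float, float]  # (gamma, beta), simplified to scalars for illustration


def assign_affine(setting: str, meta: Affine, adapted: Dict[str, Affine],
                  rng: random.Random) -> Dict[str, Affine]:
    """Return the (gamma, beta) each target domain is evaluated with."""
    domains = list(adapted)
    if setting == "no_adapt":     # meta-learned parameters for every domain
        return {d: meta for d in domains}
    if setting == "matched":      # each domain uses its own adapted parameters
        return dict(adapted)
    if setting == "not_matched":  # each domain borrows a randomly chosen other domain's parameters
        return {d: adapted[rng.choice([o for o in domains if o != d])] for d in domains}
    raise ValueError(f"unknown setting: {setting}")


adapted = {"d1": (1.1, 0.2), "d2": (0.9, -0.1), "d3": (1.0, 0.4)}
rng = random.Random(0)
for setting in ("no_adapt", "not_matched", "matched"):
    print(setting, assign_affine(setting, (1.0, 0.0), adapted, rng))
```

If "Matched" outperforms both "No adapt" and "Not matched", the adapted parameters carry information specific to their own domain rather than a generic improvement.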



t-SNE visualization of features before and after adaptation. Each data sample is represented as a point, and each color corresponds to a class randomly selected from the target domain of the iWildCam dataset.


Citation

@InProceedings{wu2023test,
  title={Test-Time Domain Adaptation by Learning Domain-Aware Batch Normalization},
  author={Wu, Yanan and Chi, Zhixiang and Wang, Yang and Plataniotis, Konstantinos N. and Feng, Songhe},
  booktitle={AAAI Conference on Artificial Intelligence},
  year={2024}
}
