On the Byzantine-Resilience of Distillation-Based Federated Learning
TL;DR: This is an informal summary of our recent paper On the Byzantine-Resilience of Distillation-Based Federated Learning by Christophe Roux, Max Zimmer and Sebastian Pokutta. We analyze the byzantine robustness of FedDistill, a federated learning paradigm where clients share predictions on a public dataset instead of model parameters. Our findings reveal that FedDistill is remarkably resilient to byzantine attacks, in which clients send malicious predictions to disrupt the training process. We introduce two new, more effective attacks in this context and propose a novel defense mechanism to enhance the robustness of knowledge distillation-based FL. Additionally, we offer a general framework to obfuscate attacks, making them harder to detect.
Written by Christophe Roux.
Warmup: Federated Learning
If you are familiar with (byzantine) FL, FedAVG and FedDistill, feel free to skip ahead.
The objective of federated learning (FL) is to centrally train a Neural Network (NN) on data that is distributed across multiple clients, without the clients having to share their data.
FedAVG
FL has become synonymous with Federated Averaging (FedAVG) [M17] or variants thereof, which work as follows: A server initializes an NN and sends the parameters to the clients, which train the model on their private data and send the updated parameters back to the server. The server then aggregates these updated parameters (e.g., by taking the mean) and sends the updated model back to the clients. This process is repeated until a satisfactory model is trained.
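To make the FedAVG loop concrete, here is a minimal NumPy sketch of it. It is an illustration under simplifying assumptions, not the setup used in the paper:
- a linear least-squares model stands in for the NN;
- local training is a few full-batch gradient steps;
- the helper names `local_train` and `fedavg` are hypothetical.

```python
import numpy as np

def local_train(w, X, y, lr=0.1, steps=10):
    # client-side training (toy): full-batch gradient steps on a least-squares loss
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w = w - lr * grad
    return w

def fedavg(clients, dim, rounds=50):
    w = np.zeros(dim)  # the server initializes the model
    for _ in range(rounds):
        # the server sends w to every client; each returns its locally updated parameters
        updates = [local_train(w, X, y) for X, y in clients]
        w = np.mean(updates, axis=0)  # aggregation step: plain mean of the parameters
    return w

rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0, 0.5])
clients = []
for _ in range(5):
    X = rng.normal(size=(50, 3))
    clients.append((X, X @ w_true))  # private, noise-free toy data of one client
w = fedavg(clients, dim=3)
```

All toy clients here share the same optimum, so the averaged model approaches `w_true`. With heterogeneous client data this is no longer guaranteed.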
FedDistill
In Federated Learning with Knowledge Distillation (FedDistill), we assume that in addition to the private datasets held by the clients, there exists a publicly available unlabeled dataset. As in FedAVG, the server initializes an NN and sends the parameters to the clients. After training the model locally on their private data, the clients compute their predictions on the public dataset and send them to the server, instead of their parameters. The server then uses Knowledge Distillation (KD) to distill this information into its model by training on the public dataset using the aggregated client predictions as labels. This process is repeated until a satisfactory model is trained.
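The same loop for FedDistill can be sketched as follows. Again, this is a hypothetical toy rather than the paper's implementation:
- a linear softmax classifier on 2-D Gaussian blobs stands in for the NN;
- the public dataset `X_pub` is drawn from the same distribution as the private data;
- `fit` and `feddistill_round` are names introduced only for illustration.

The key difference from FedAVG is that clients return `softmax(X_pub @ W_local)`, never `W_local` itself.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def fit(W, X, Y, lr=0.1, steps=200):
    # gradient descent on the mean cross-entropy between softmax(X @ W) and (soft) targets Y
    for _ in range(steps):
        W = W - lr * X.T @ (softmax(X @ W) - Y) / len(X)
    return W

def feddistill_round(W_server, clients, X_pub):
    n_classes = W_server.shape[1]
    preds = []
    for X, y in clients:
        W_local = fit(W_server, X, np.eye(n_classes)[y])  # local training on private labels
        preds.append(softmax(X_pub @ W_local))           # only predictions on D_p are shared
    Y_bar = np.mean(preds, axis=0)                        # server aggregates by the mean
    return fit(W_server, X_pub, Y_bar), Y_bar             # distill Y_bar into the server model

rng = np.random.default_rng(0)
centers = np.array([[3.0, 0.0], [0.0, 3.0], [-3.0, -3.0]])

def sample(n):
    y = rng.integers(3, size=n)
    return centers[y] + rng.normal(size=(n, 2)), y

clients = [sample(100) for _ in range(5)]
X_pub, y_pub = sample(300)  # y_pub is never given to the server; used only for evaluation
W = np.zeros((2, 3))
for _ in range(3):
    W, Y_bar = feddistill_round(W, clients, X_pub)
acc = (softmax(X_pub @ W).argmax(axis=1) == y_pub).mean()
```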
Byzantine FL
Byzantine FL refers to the setting where a fraction of the clients are byzantine, i.e., they exhibit adversarial behaviour. Typically, it is assumed that the goal of such adversaries is to derail the training process and prevent the server and the benign clients from training a useful model. For FedAVG, such attacks would involve sending malicious weights to the server, for example to disrupt the aggregation step. For FedDistill, byzantine clients can send malicious predictions to the server.
Motivation
One weakness of Federated Averaging (FedAVG) is its susceptibility to byzantine attacks, where clients send malicious updates. In fact, just one byzantine client can completely derail the learning process. This has motivated a long line of research aiming to design byzantine-resilient FedAVG variants.
FedDistill has emerged as a promising alternative to FedAVG that reduces communication overhead and improves both privacy and robustness to non-i.i.d. data. In FedDistill, clients communicate information about the learning task by sharing their predictions on a public dataset rather than transmitting model parameters.
Our observations reveal that vanilla FedDistill shows remarkable resilience to byzantine attacks (see Figure 1). This work aims to take a closer look at FedDistill’s robustness in the byzantine setting.
Figure 1: ResNet-18 on CINIC-10: Final test accuracy of FedAVG and FedDistill when varying the fraction of byzantine clients, for two naive attacks. For FedAVG, the byzantine clients simply send Gaussian noise (GN) instead of parameter updates. For FedDistill, they send random one-hot predictions; we refer to this as the Random Label Flip (RLF) attack.
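The RLF attack from Figure 1 fits in a few lines. This sketch assumes that each byzantine client draws an independent, uniformly random class for every public sample. `rlf_attack` is a hypothetical helper name.

```python
import numpy as np

def rlf_attack(n_samples, n_classes, rng):
    # one uniformly random class per public sample, sent as a one-hot "prediction"
    labels = rng.integers(n_classes, size=n_samples)
    return np.eye(n_classes)[labels]

y_b = rlf_attack(n_samples=5, n_classes=10, rng=np.random.default_rng(0))
```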
Comparing the attack vectors of FedAVG and FedDistill
Before discussing specific attacks and defences, let's compare how byzantine clients can influence the learning process in FedAVG and FedDistill. We write \(\mathcal{B}\) for the set of byzantine clients and \(\mathcal{H}\) for the set of honest clients.
In FedAVG, the server updates the parameters by taking the mean of the clients' parameters, i.e., \(\bar{w}\gets \frac{1}{N}\sum_{i=1}^N w_i = \underbrace{\frac{1}{N}\sum_{i\in \mathcal{B}} w_i}_{\text{threat vector}} +\frac{1}{N}\sum_{i\in \mathcal{H}} w_i.\) Since the byzantine parameters \(w_i\), \(i\in\mathcal{B}\), are unconstrained, just one byzantine client can move \(\bar{w}\) to an arbitrary point.
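A short numerical illustration of this asymmetry follows. It assumes an omniscient byzantine client that sees the honest updates before choosing its own, a common worst-case assumption.
- In FedAVG, this single client can place \(\bar{w}\) exactly on any target.
- In FedDistill, it can only send a point of the probability simplex. The mean prediction then differs from the honest mean \(\bar{Y}_{\mathcal{H}}\) by \((Y_b-\bar{Y}_{\mathcal{H}})/N\), which is at most \(2/N\) in \(\ell_1\) distance.

The dimensions and values below are arbitrary toy choices.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 10  # clients in total, exactly one of which is byzantine

# FedAVG: the byzantine client solves for the update that drags the mean onto `target`
honest_w = rng.normal(size=(N - 1, 4))
target = np.full(4, 1e6)                      # any vector the attacker wants
byz_w = N * target - honest_w.sum(axis=0)
w_bar = (honest_w.sum(axis=0) + byz_w) / N    # equals target

# FedDistill: try every one-hot prediction on a single public sample
honest_p = rng.dirichlet(np.ones(3), size=N - 1)
y_h = honest_p.mean(axis=0)                   # honest mean prediction
shifts = [np.abs((honest_p.sum(axis=0) + e) / N - y_h).sum() for e in np.eye(3)]
```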
In FedDistill, the server updates its parameters by distilling the knowledge of the clients, based on their predictions on the public dataset \(D_p\), i.e., by solving the following optimization problem
\[\min_w \sum_{x\in D_p} \mathcal{L} (h(x,w),\underbrace{\bar{Y}(x)}_{\text{threat vector}}), \tag{$\mathcal{P}_\text{distill}$}\]
where \(h\) is the classifier, \(\mathcal{L}\) is a loss function such as the cross-entropy loss and \(\bar{Y}\) is the mean of the clients' predictions.
Looking at the threat vectors, two clear advantages of FedDistill emerge. First, in FedDistill the byzantine clients can only influence the server parameters indirectly, via (\(\mathcal{P}_\text{distill}\)), as opposed to the direct impact they have in FedAVG. Second, the predictions lie in the probability simplex, a bounded set, which limits the impact of any attack. In fact, we show that the difference between the gradients based on the perturbed prediction \(\bar{Y}\) and the gradients based on only honest predictions scales linearly with the difference between the honest and the perturbed predictions. This implies that the distance between a stationary point of (\(\mathcal{P}_\text{distill}\)) perturbed by byzantine predictions and a stationary point of its honest counterpart is bounded.
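As a worked instance of this bound, here is a sketch for one particular case, not the general result from the paper. Assume:
- \(h(x,w)=\operatorname{softmax}(z(x,w))\) for logits \(z\), and \(\mathcal{L}\) is the cross-entropy loss;
- a fraction \(\alpha\) of the clients is byzantine, so \(\bar{Y} = (1-\alpha)\bar{Y}_{\mathcal{H}} + \alpha\bar{Y}_{\mathcal{B}}\), where \(\bar{Y}_{\mathcal{H}}\) and \(\bar{Y}_{\mathcal{B}}\) are the mean honest and mean byzantine predictions.

For any target \(Y\) in the probability simplex, the gradient of the cross-entropy with respect to the logits is \(h(x,w)-Y\). Writing \(J_z(x,w)\) for the Jacobian of \(z\) with respect to \(w\), we get
\[
\begin{aligned}
\nabla_w \mathcal{L}(h(x,w),Y) &= J_z(x,w)^\top \big(h(x,w) - Y\big),\\
\nabla_w \mathcal{L}(h(x,w),\bar{Y}(x)) - \nabla_w \mathcal{L}(h(x,w),\bar{Y}_{\mathcal{H}}(x)) &= J_z(x,w)^\top \big(\bar{Y}_{\mathcal{H}}(x) - \bar{Y}(x)\big) = \alpha\, J_z(x,w)^\top \big(\bar{Y}_{\mathcal{H}}(x) - \bar{Y}_{\mathcal{B}}(x)\big),\\
\big\|\nabla_w \mathcal{L}(h(x,w),\bar{Y}(x)) - \nabla_w \mathcal{L}(h(x,w),\bar{Y}_{\mathcal{H}}(x))\big\|_2 &\le \sqrt{2}\,\alpha\, \|J_z(x,w)\|_2 .
\end{aligned}
\]
The last step uses the fact that any two points of the probability simplex are at most \(\sqrt{2}\) apart in Euclidean distance. In this case, the gradient perturbation is linear in the prediction gap and can never exceed \(\sqrt{2}\,\alpha\,\|J_z(x,w)\|_2\). By contrast, the FedAVG threat vector is unbounded.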
Proposing new attacks
To find out whether FedDistill really is resilient to byzantine clients, or whether this is just an artifact of weak attacks, we need to look for stronger attacks. The prior literature only considers simple label-flipping attacks such as the one we used in Figure 1.
In general, the goal of the byzantine clients is to create predictions which prevent the server from learning a useful classifier. The server uses the predictions made by the clients as a target. Therefore, the goal of the byzantine clients can be seen as choosing their predictions such that the resulting mean prediction \(\bar{Y}\) is as different as possible from the mean of the honest predictions. The question is: How do we quantify this difference or similarity between predictions?
Figure 2: Attack procedures for a three-class classification problem with four honest and three byzantine clients. The left part of the figure shows the computation of the honest mean. LMA (upper right) assigns probability one to the least likely class based on the honest mean \(\bar{Y}_{\mathcal{H}}\) and CPA (lower right) assigns probability one to the class that is least similar to the most likely class of \(\bar{Y}_{\mathcal{H}}\), according to the similarity matrix \(C\).
The first attack, which we call the Loss Maximization Attack (LMA), simply measures the difference between predictions using the loss function \(\mathcal{L}\) that the server uses for training. It chooses the byzantine prediction that maximizes the loss of the resulting mean prediction \(\bar{Y}\) with respect to the mean honest prediction \(\bar{Y}_{\mathcal{H}}\). It turns out that for typical loss functions such as the cross-entropy loss, this optimization problem can be solved analytically. The solution is to predict the class to which the mean of the honest clients assigns the smallest probability, see Figure 2.
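The closed form can be checked numerically. The sketch below makes three assumptions:
- all byzantine clients send the same prediction \(Y_{\mathcal{B}}\);
- they make up a fraction \(\alpha\) of the clients, so that \(\bar{Y} = (1-\alpha)\bar{Y}_{\mathcal{H}} + \alpha Y_{\mathcal{B}}\);
- the loss is the cross-entropy of \(\bar{Y}\) against \(\bar{Y}_{\mathcal{H}}\).

`lma_attack` implements the argmin rule. The hypothetical helper `best_vertex` brute-forces the loss over all one-hot predictions to confirm that the two agree. Because this loss is convex in \(Y_{\mathcal{B}}\), its maximum over the simplex is attained at a vertex, so checking one-hot vectors suffices.

```python
import numpy as np

def lma_attack(honest_preds):
    # honest_preds: (n_honest, n_samples, n_classes) -> one-hot byzantine predictions
    mean_h = honest_preds.mean(axis=0)
    return np.eye(mean_h.shape[-1])[mean_h.argmin(axis=-1)]

def server_loss(y_bar, mean_h, eps=1e-12):
    # cross-entropy of the poisoned mean y_bar measured against the honest mean
    return -(mean_h * np.log(np.clip(y_bar, eps, 1.0))).sum(axis=-1)

def best_vertex(mean_h, alpha):
    # brute force: evaluate the loss for every one-hot byzantine prediction
    losses = [server_loss((1 - alpha) * mean_h + alpha * e, mean_h)
              for e in np.eye(mean_h.shape[-1])]
    return np.stack(losses, axis=-1).argmax(axis=-1)

honest = np.array([[[0.6, 0.3, 0.1]], [[0.5, 0.4, 0.1]], [[0.7, 0.1, 0.2]]])
y_b = lma_attack(honest)                        # one-hot on class 2, the least likely one
check = best_vertex(honest.mean(axis=0), alpha=0.45)
```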
The second approach, the Class Prior Attack (CPA), evaluates differences based on the semantic similarity between classes; for example, the class “dog” is more similar to “cat” than to “ship.” This similarity information is captured in a similarity matrix that assigns a similarity score to each pair of classes. In our experiments, we create such a similarity matrix by computing the covariance of the logits of a pretrained model, see Figure 2.
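CPA can be sketched in the same style. The post does not spell out exactly how the logits are turned into \(C\). `similarity_from_logits` therefore assumes one logit vector per sample from a pretrained model and takes the class-by-class covariance with `np.cov`. For each public sample, `cpa_attack` then sends a one-hot prediction on the class with the lowest similarity to the honest mean's most likely class. Both function names and the toy dog/cat/ship matrix below are hypothetical.

```python
import numpy as np

def similarity_from_logits(logits):
    # logits: (n_samples, n_classes) from a pretrained model -> (n_classes, n_classes)
    return np.cov(logits, rowvar=False)

def cpa_attack(honest_preds, C):
    mean_h = honest_preds.mean(axis=0)        # (n_samples, n_classes)
    top = mean_h.argmax(axis=-1)              # most likely class per sample
    sim = C[top].astype(float)                # similarity of every class to that class
    sim[np.arange(len(top)), top] = np.inf    # never pick the top class itself
    return np.eye(C.shape[0])[sim.argmin(axis=-1)]  # least similar class, one-hot

# toy similarity matrix for the classes (dog, cat, ship)
C = np.array([[1.0, 0.8, -0.5], [0.8, 1.0, -0.3], [-0.5, -0.3, 1.0]])
honest = np.array([[[0.7, 0.2, 0.1], [0.1, 0.2, 0.7]],
                   [[0.5, 0.4, 0.1], [0.2, 0.1, 0.7]]])
y_b = cpa_attack(honest, C)  # "dog" samples get "ship", "ship" samples get "dog"
```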
Both CPA and LMA are significantly more effective than RLF and heavily impact the final test accuracy, see Figure 3. Note that here a lower accuracy indicates a stronger attack, since the byzantine clients try to disrupt the training process.
Figure 3: ResNet-18 on CINIC-10: Test accuracy over communication rounds when attacking FedDistill with 9 byzantine clients out of 20 clients in total.
Proposing new defences
The straightforward way to make FedDistill more robust is to use a robust aggregation method. One such method suggested in the literature is Cronus [C19], a filtering approach grounded in high-dimensional robust statistics. However, these approaches are limited because they treat each prediction individually. They ignore the information in the clients' predictions on the other samples of the same communication round, as well as the information from past communication rounds.
To leverage this additional information, we propose ExpGuard, which is used in combination with a robust aggregation method. ExpGuard tracks how much each client's predictions differ from the robustly aggregated prediction and computes a weight for each client. The predictions are then aggregated by computing the weighted mean of the client predictions.
It turns out that ExpGuard improves the resilience of all aggregation methods, often to the point that the accuracy is almost as high as in the honest setting.
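The post does not spell out ExpGuard's update rule. The following is therefore only a sketch in the spirit of the description above, not the paper's algorithm. It assumes an exponential-weights scheme:
- each client's mean \(\ell_1\) distance to the robust aggregate, taken over all public samples, is accumulated across rounds;
- each client's weight is proportional to \(\exp(-\eta \cdot \text{cumulative distance})\).

The step size `eta` is a hypothetical parameter of this sketch. A coordinate-wise median, renormalized onto the simplex, stands in for a robust aggregator such as Cronus. `ExpGuardSketch` and `median_aggregate` are illustrative names.

```python
import numpy as np

def median_aggregate(preds):
    # stand-in robust aggregator: coordinate-wise median, renormalized onto the simplex
    med = np.median(preds, axis=0)
    s = med.sum(axis=-1, keepdims=True)
    uniform = np.full_like(med, 1.0 / med.shape[-1])
    return np.where(s > 0, med / np.maximum(s, 1e-12), uniform)

class ExpGuardSketch:
    def __init__(self, n_clients, eta=1.0):
        self.eta = eta                          # hypothetical step size of this sketch
        self.cum_dist = np.zeros(n_clients)     # memory across communication rounds

    def aggregate(self, preds):
        # preds: (n_clients, n_samples, n_classes)
        robust = median_aggregate(preds)
        # per-client mean L1 distance to the robust aggregate over all public samples
        dist = np.abs(preds - robust).sum(axis=-1).mean(axis=-1)
        self.cum_dist += dist
        scores = -self.eta * self.cum_dist
        weights = np.exp(scores - scores.max())  # shift for numerical stability
        weights /= weights.sum()
        return np.tensordot(weights, preds, axes=1), weights  # weighted mean prediction

honest = np.tile([0.7, 0.2, 0.1], (4, 5, 1))     # 4 honest clients, 5 public samples
byzantine = np.tile([0.0, 0.0, 1.0], (2, 5, 1))  # 2 LMA-style byzantine clients
preds = np.concatenate([honest, byzantine])
guard = ExpGuardSketch(n_clients=6)
for _ in range(3):
    agg, weights = guard.aggregate(preds)
```

Averaging the distance over the whole public dataset and accumulating it over rounds is what lets a consistently deviating client be down-weighted. A filter that inspects each prediction in isolation cannot do this.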
Attack obfuscation: Hiding In Plain Sight
While it’s encouraging to see ExpGuard perform so effectively, this outcome is expected since our attacks are maximally disruptive, without considering how detectable they are. In fact, all of the attacks we discussed assign probability one to a single class and zero to all others. However, as defense methods are introduced, a tradeoff emerges between attack strength and detectability.
To better navigate this tradeoff, we propose Hiding In Plain Sight (HIPS), a method designed to obfuscate attacks and make them harder to detect by limiting how much the byzantine predictions can differ from the honest ones. In practice, we achieve this by constraining the byzantine prediction to remain within the convex hull of all honest predictions.
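Here is one simple way to realize the HIPS constraint for the LMA objective, assuming the same cross-entropy setup as in the LMA discussion. That loss is convex in the byzantine prediction. Its maximum over the convex hull of the honest predictions is therefore attained at one of the honest predictions, so it suffices to try each of them. This is an illustrative sketch, not necessarily how the paper implements HIPS. `poisoned_ce`, `hips_lma` and the byzantine fraction `alpha` are assumptions of this example.

```python
import numpy as np

def poisoned_ce(y_b, mean_h, alpha, eps=1e-12):
    # cross-entropy of the poisoned mean (1 - alpha) * mean_h + alpha * y_b against mean_h
    y_bar = (1 - alpha) * mean_h + alpha * y_b
    return -(mean_h * np.log(np.clip(y_bar, eps, 1.0))).sum(axis=-1)

def hips_lma(honest_preds, alpha):
    # honest_preds: (n_honest, n_samples, n_classes)
    mean_h = honest_preds.mean(axis=0)
    # loss for each hull vertex (= each honest prediction), shape (n_honest, n_samples)
    losses = np.stack([poisoned_ce(p, mean_h, alpha) for p in honest_preds])
    best = losses.argmax(axis=0)
    return honest_preds[best, np.arange(mean_h.shape[0])]  # (n_samples, n_classes)

honest = np.array([[[0.6, 0.3, 0.1]], [[0.5, 0.4, 0.1]], [[0.7, 0.1, 0.2]]])
mean_h = honest.mean(axis=0)
y_b = hips_lma(honest, alpha=0.45)  # always a copy of some honest prediction
```

The obfuscated prediction is indistinguishable from an honest one on each sample. The price is a smaller loss than the unconstrained one-hot LMA prediction achieves.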
Figure 4: Attack spaces in \(\Delta_3\): The blue dots represent the predictions of the honest clients and \(\bar{Y}_{\mathcal{H}}\) is their mean. The attack space is highlighted in yellow. \(\bar{Y}_{\mathcal{B}}\) is the byzantine prediction, and \(\bar{Y}\) denotes the mean of all clients for \(\alpha=0.5\), where \(\alpha\) is the fraction of byzantine clients. The red line joining them represents the mean \(\bar{Y}\) for different \(\alpha\in [0,0.5]\).
As expected, combining our attacks with HIPS greatly improves their effectiveness against the various defence methods. However, because it reduces the attack amplitude, HIPS makes attacks less effective when no defence is in place. In fact, plain mean aggregation is often slightly more resilient to HIPS-obfuscated attacks than the other defences. Since we do not know in advance which attack strategy the byzantine clients will use, the best defence is the one whose accuracy drops least under the most effective attack against it. This ensures that, regardless of which attack is employed, the defence performs reliably and the worst-case impact on model accuracy is minimized. ExpGuard does not significantly affect the accuracy of FedDistill in the absence of byzantine clients, requires no hyperparameter tuning, and can be computed efficiently. This makes it a promising method to reduce the impact of arbitrary failures in KD-based FL methods.
References
[C19] Hongyan Chang, Virat Shejwalkar, Reza Shokri, and Amir Houmansadr. Cronus: Robust and Heterogeneous Collaborative Learning with Black-Box Knowledge Transfer. arXiv:1912.11279 [cs, stat], December 2019.
[M17] H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Agüera y Arcas. Communication-Efficient Learning of Deep Networks from Decentralized Data. arXiv:1602.05629 [cs], February 2017.