Towards Federated Learning at Scale: System Design (Reading Notes)



Google has implemented the first production-level federated learning system and published the paper “Towards Federated Learning at Scale: System Design.” The paper describes the design of the system, the philosophy behind it, and the challenges that remain, along with the solutions Google proposes.

DeepMind research scientist Andrew Trask said on Twitter: “This is one of the most exciting papers of 2019. Google has announced how they can implement scalable federated learning solutions on tens of millions of mobile phones.”

The system described in the paper is the first production-level implementation of federated learning, with a primary focus on running the Federated Averaging algorithm on mobile phones. The longer-term goal is to generalize the system from Federated Learning to Federated Computation, which is not restricted to machine learning with TensorFlow but covers general MapReduce-like workloads. One application area is Federated Analytics, which makes it possible to monitor aggregate statistics across a large fleet of devices without logging raw device data to the cloud.

Federated Learning Procedures

As shown in the following figure, the federated learning process can be divided into three phases, namely Selection, Configuration, and Reporting.

In the Selection phase, devices that meet the eligibility criteria (in the paper, for example, charging and connected to an unmetered network) check in with the server to announce that they can participate in the current round of training. The server selects a subset of these devices for the round and tells the devices it does not select to check in again after a period of time. In making this decision, the server takes into account factors such as the target number of participants and the timeout period. The round proceeds only if enough devices check in before the timeout.

In the Configuration phase, the server is configured according to the aggregation mechanism chosen for the selected devices, and each device is configured by receiving the specific FL task and the current FL checkpoint from the server.

In the Reporting phase, the server waits for the participating devices to report their results. As results arrive, the server combines them using the aggregation algorithm and tells each reporting device when to check in next. If enough devices report before the timeout, the round succeeds; otherwise, the round fails.
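The three phases can be sketched as a small simulation. This is only an illustration under stated assumptions: `RoundConfig`, `run_round`, `goal_count`, `min_report_fraction`, and `dropout_prob` are hypothetical names, and device dropout is modeled as an independent coin flip, which the notes do not specify.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class RoundConfig:
    goal_count: int             # devices the server wants in the round (assumed name)
    min_report_fraction: float  # share of selected devices that must report (assumed name)


def run_round(checked_in, config, dropout_prob, rng):
    """Simulate one round and return (status, reporting_devices)."""
    # Selection: abandon the round if too few devices checked in before the timeout.
    if len(checked_in) < config.goal_count:
        return "abandoned", []
    selected = rng.sample(checked_in, config.goal_count)

    # Configuration: the server would send the FL task and checkpoint here.

    # Reporting: each selected device may drop out before it reports.
    reports = [d for d in selected if rng.random() >= dropout_prob]
    needed = math.ceil(config.min_report_fraction * config.goal_count)
    status = "succeeded" if len(reports) >= needed else "failed"
    return status, reports


rng = random.Random(0)
status, reports = run_round(list(range(100)), RoundConfig(10, 0.5), 0.1, rng)
print(status, len(reports))
```

The point of the sketch is that a round can fail in two places: before it starts, when too few devices check in, and at the end, when too few of the selected devices report back.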

Across all three phases, a mechanism called Pace Steering manages device connection patterns. For small-scale FL training, Pace Steering ensures that enough devices are connected at the same time to take part in each round. For large-scale FL training, it randomizes device check-in times so that a large number of devices do not hit the server at the same moment.
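A minimal sketch of the two Pace Steering regimes might look like the following. The function name, the `small_population` threshold, the 5% jitter, and the window formula are all assumptions made for illustration; the notes do not describe how the real system computes the suggested reconnect time.

```python
import random


def suggest_reconnect_delay(population, goal_count, round_period_s, rng,
                            small_population=1000):
    """Return how many seconds a device should wait before checking in again."""
    if population <= small_population:
        # Small population: bring devices back together at the next round,
        # with a little jitter, so that enough of them are connected at once.
        return round_period_s + rng.uniform(0, 0.05 * round_period_s)
    # Large population: spread check-ins over a window sized so that roughly
    # goal_count devices arrive per round period, instead of all at once.
    window_s = round_period_s * population / goal_count
    return rng.uniform(0, window_s)


rng = random.Random(0)
print(suggest_reconnect_delay(500, 100, 60.0, rng))
print(suggest_reconnect_delay(1_000_000, 1000, 60.0, rng))
```

With these assumed numbers, a population of one million devices and a goal of 1000 devices per 60-second round spreads check-ins over a window of 60,000 seconds.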


A key question in designing federated learning infrastructure is whether to focus on asynchronous or synchronous training algorithms. Much earlier work in deep learning used asynchronous training, but recently there has been a trend toward large-batch synchronous training. In addition, approaches for strengthening privacy in federated learning, including differential privacy (McMahan et al., 2018), require some notion of synchronization on a fixed set of devices, so that the server side of the learning algorithm consumes only a simple aggregate of the updates from many users.

Therefore, the Google researchers chose synchronous training, which can run large-batch SGD as well as the Federated Averaging algorithm. Federated Averaging is the main algorithm the system runs in production. The algorithm is shown below:

K represents the total number of nodes, B is the size of the local minibatch, E is the number of local training epochs, and C is the fraction of nodes selected in each round. The algorithm is split into a server side and a device side.

On the server side, the server first selects a subset of nodes to participate in the current round and sends them the current model. After the nodes return their trained models, the server computes a weighted average of each parameter.

On the device side, the node first splits its local data into batches of size B, then trains on those batches for E epochs using gradient descent, and finally returns the updated model to the server.

The experiments discussed later use Federated SGD (FedSGD) as a comparison baseline. FedSGD is the special case in which each selected node takes a single full-batch gradient step per round, that is, E = 1 with B equal to the size of the local dataset.
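The server and device steps can be sketched in plain Python. This is a minimal illustration, not the production code: it assumes a tiny least-squares model, takes the number of nodes K as `len(clients)`, weights each returned model by its share of the examples held by the selected nodes, and uses the hypothetical names `client_update` and `federated_averaging`.

```python
import random


def client_update(w, data, epochs, batch_size, lr, rng):
    """Device side: E epochs of minibatch SGD on a least-squares model y ~ w.x."""
    w = list(w)
    for _ in range(epochs):
        data = data[:]
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            grad = [0.0] * len(w)
            for x, y in batch:
                err = sum(wi * xi for wi, xi in zip(w, x)) - y
                for j, xj in enumerate(x):
                    grad[j] += 2 * err * xj / len(batch)
            w = [wi - lr * g for wi, g in zip(w, grad)]
    return w


def federated_averaging(clients, rounds, C, E, B, lr, dim, seed=0):
    """Server side: select a fraction C of nodes, then average their models."""
    rng = random.Random(seed)
    w = [0.0] * dim
    for _ in range(rounds):
        m = max(1, round(C * len(clients)))
        chosen = rng.sample(range(len(clients)), m)
        n_total = sum(len(clients[k]) for k in chosen)
        new_w = [0.0] * dim
        for k in chosen:
            w_k = client_update(w, clients[k], E, B, lr, rng)
            share = len(clients[k]) / n_total  # weight by local data size
            new_w = [nw + share * wk for nw, wk in zip(new_w, w_k)]
        w = new_w
    return w


# Toy data: four nodes, each holding noise-free samples of y = 3x.
clients = [[([0.5 + 0.1 * i + 0.05 * k], 3 * (0.5 + 0.1 * i + 0.05 * k))
            for i in range(4)] for k in range(4)]
print(federated_averaging(clients, rounds=30, C=0.5, E=5, B=2, lr=0.1, dim=1))
```

Setting `E=1` and `B` to the size of each node's dataset turns the same loop into the FedSGD baseline, since every selected node then takes exactly one full-batch step per round.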

The system described in the paper uses TensorFlow to train deep neural networks on data stored on mobile phones. The Federated Averaging algorithm combines the model updates from the phones in the cloud to build a global model, which is then pushed back to the phones for inference. Secure aggregation ensures that, at the global level, individual updates from phones cannot be inspected.

The paper points out that the system addresses many practical problems: device availability that correlates with the local data distribution in complex ways (such as time zone dependence), unreliable device connections and interrupted execution, orchestrating lock-step execution across devices with varying availability, and limited device storage and compute resources. These issues are addressed at the communication protocol, device, and server levels.

Federated learning works best when the data on the device is more relevant than the data available on the server and is sensitive from a privacy standpoint. Currently, federated learning is mostly used for supervised learning tasks, usually with labels inferred from user activity (such as clicks or typed words).


To evaluate the effectiveness of the federated learning approach, the datasets used can be divided into two groups: generic datasets and real-world datasets.

The generic datasets are MNIST and a Shakespeare corpus, and the authors split each one into both an independent and identically distributed (IID) partition and a non-IID partition. They built a multi-layer perceptron and a CNN to train on MNIST, and an LSTM to train on the Shakespeare dataset. The authors found that increasing the degree of parallelism (the number of nodes trained per round) can reduce the number of communication rounds. Increasing the amount of local training on each node reduces the number of communication rounds even more, and FedAvg outperforms FedSGD. However, increasing the number of local training epochs too far often does not lead to better results.
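The difference between an IID and a non-IID split can be made concrete with a short sketch. The notes do not say how the partitions were built; the sort-by-label-and-shard scheme below is one common way to create a strongly non-IID split and is an assumption here, as are the function names and the `shards_per_client` parameter.

```python
import random


def partition_iid(examples, num_clients, rng):
    """Shuffle, then deal examples round-robin so every client sees all labels."""
    shuffled = examples[:]
    rng.shuffle(shuffled)
    return [shuffled[i::num_clients] for i in range(num_clients)]


def partition_non_iid(examples, num_clients, shards_per_client, rng):
    """Sort (features, label) pairs by label, cut into shards, deal out shards."""
    by_label = sorted(examples, key=lambda ex: ex[1])
    num_shards = num_clients * shards_per_client
    shard_size = len(by_label) // num_shards
    shards = [by_label[i * shard_size:(i + 1) * shard_size]
              for i in range(num_shards)]
    rng.shuffle(shards)
    return [sum(shards[c * shards_per_client:(c + 1) * shards_per_client], [])
            for c in range(num_clients)]


rng = random.Random(0)
examples = [(i, i % 10) for i in range(1000)]  # toy (features, label) pairs
non_iid = partition_non_iid(examples, num_clients=10, shards_per_client=2, rng=rng)
print([sorted({label for _, label in client}) for client in non_iid])
```

In the toy run each client ends up with at most two distinct labels, which is what makes averaging across clients harder than in the IID case.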

The real-world datasets are CIFAR and a social network dataset. The authors built a CNN to train on CIFAR and an LSTM to train on the social network dataset. FedAvg needs fewer rounds to reach the same accuracy, and its final accuracy is better than that of the FedSGD baseline.


On-device item ranking

A common use of machine learning models in mobile applications is to select and rank items from an on-device inventory. For example, an app can expose a search mechanism for information retrieval or in-app navigation. Ranking search results on the device eliminates calls to the server that would be costly in latency, bandwidth, or power consumption, and any potentially private information about search queries and user selections stays on the device. Each user interaction with the ranking feature can serve as a labeled data point, because the user’s interaction with the preferred item can be observed in the context of the full ranked list.

Content suggestions for on-device keyboards

An on-device keyboard can add value for users by suggesting content relevant to the text they enter. Federated learning can be used to train machine learning models that trigger the suggestion feature and rank the items that can be suggested in the current context. Google’s Gboard mobile keyboard team uses the federated learning system in this way.

Next word prediction

Gboard also uses the federated learning platform to train a recurrent neural network (RNN) for next word prediction. The model has about 1.4 million parameters. It converges after 3000 rounds of federated learning, which took 5 days of training and processed 600 million sentences from 1.5 million users (about 2-3 minutes per round). The model raises top-1 recall from 13.0% for the baseline n-gram model to 16.4%, matching the performance of a server-trained RNN that required about 120 million SGD steps. In live A/B experiments, the federated model outperforms both the n-gram and the server-trained RNN models.
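The reported figures can be cross-checked with a few lines of arithmetic. The snippet takes the numbers exactly as quoted above; the per-user and per-round values it derives are simple averages, not figures stated in the notes.

```python
rounds = 3000
training_days = 5
sentences = 600_000_000
users = 1_500_000

minutes_per_round = training_days * 24 * 60 / rounds  # 7200 / 3000
sentences_per_user = sentences / users                # average over all users
sentences_per_round = sentences / rounds              # average over all rounds
recall_gain_points = 16.4 - 13.0                      # absolute gain in recall
relative_gain = recall_gain_points / 13.0

print(minutes_per_round)          # 2.4, consistent with "2-3 minutes per round"
print(sentences_per_user)         # 400.0
print(sentences_per_round)        # 200000.0
print(round(relative_gain, 3))    # 0.262, i.e. roughly a 26% relative gain
```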

Future research direction

State analysis

Crashes and other problems can occur during federated learning training. Because most of the training activity takes place on the device, where the server can neither control nor observe it, the server cannot determine on its own what went wrong when a crash or other problem is discovered. The system therefore needs to analyze how the devices and the server interact. During each training round, the device records its activity and health parameters. This data is generally not privacy-sensitive, so it can be uploaded to the cloud, and the server collects similar data, such as how many devices were accepted or rejected. By analyzing this data, engineers can understand what happened, identify what went wrong, and propose a fix.

Secure aggregation

To further protect the privacy of each device, secure aggregation can be used. Secure aggregation takes place during the Reporting phase of federated learning. It is a four-round protocol organized into three phases. The first is the preparation phase, in which the devices generate the secrets to be shared; if a device disconnects during this phase, its result is not included in the final aggregate. The second is the commit phase, in which each device encrypts its own result and uploads the encrypted result to the server. The last is the finalization phase, in which the devices send the server the information needed for decoding, and the server uses it to decode and obtain the aggregate.
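The core idea, that the server can recover the sum without seeing any individual result, can be illustrated with a toy pairwise-masking scheme. This is not the protocol from the paper: the shared `seed` stands in for a real key agreement between devices, there is no dropout recovery, and nothing here is cryptographically secure. All names are hypothetical.

```python
import random

MODULUS = 2 ** 32  # updates are assumed to be integer vectors modulo 2**32


def pairwise_masks(device_ids, vector_len, seed):
    """For each pair (a, b), derive a shared mask; a adds it and b subtracts it."""
    masks = {d: [0] * vector_len for d in device_ids}
    for idx, a in enumerate(device_ids):
        for b in device_ids[idx + 1:]:
            pair_rng = random.Random(f"{seed}-{a}-{b}")  # stand-in for a shared secret
            shared = [pair_rng.randrange(MODULUS) for _ in range(vector_len)]
            masks[a] = [(m + s) % MODULUS for m, s in zip(masks[a], shared)]
            masks[b] = [(m - s) % MODULUS for m, s in zip(masks[b], shared)]
    return masks


def mask_update(update, mask):
    """Commit step: the device uploads only the masked vector."""
    return [(u + m) % MODULUS for u, m in zip(update, mask)]


def server_sum(masked_updates):
    """The pairwise masks cancel in the sum, leaving the true aggregate."""
    total = [0] * len(masked_updates[0])
    for vec in masked_updates:
        total = [(t + v) % MODULUS for t, v in zip(total, vec)]
    return total


updates = {"a": [1, 2, 3], "b": [10, 20, 30], "c": [100, 200, 300]}
masks = pairwise_masks(list(updates), 3, seed="round-1")
masked = [mask_update(updates[d], masks[d]) for d in updates]
print(server_sum(masked))  # [111, 222, 333]
```

Because every mask is added once and subtracted once, the masks vanish only when all devices in the group contribute, which is why a real protocol needs extra machinery to handle devices that drop out.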

Secure aggregation increases the computational cost on the server, which limits the number of devices that can take part in one aggregation. To work around this, the authors run secure aggregation on each aggregator, which combines the results of the devices it is responsible for into an intermediate value. The primary aggregator then aggregates the intermediate values into the final result.
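The two-level scheme can be sketched with plain sums. The sketch assumes, purely for illustration, that the cost of one secure aggregation instance grows with the number of device pairs in it; `group_size`, `pairwise_cost`, and `aggregate_hierarchically` are hypothetical names.

```python
def pairwise_cost(n):
    """Number of device pairs in a group of n devices (assumed cost model)."""
    return n * (n - 1) // 2


def aggregate_hierarchically(updates, group_size):
    """Sum each group of device updates, then sum the intermediate results."""
    groups = [updates[i:i + group_size]
              for i in range(0, len(updates), group_size)]
    intermediate = [[sum(col) for col in zip(*group)] for group in groups]
    final = [sum(col) for col in zip(*intermediate)]
    return intermediate, final


updates = [[i, 2 * i] for i in range(10)]
intermediate, final = aggregate_hierarchically(updates, group_size=4)
print(intermediate)  # [[6, 12], [22, 44], [17, 34]]
print(final)         # [45, 90]

# Under the assumed cost model, 1000 devices in one group versus 10 groups of 100:
print(pairwise_cost(1000), 10 * pairwise_cost(100))  # 499500 49500
```

The final sum is identical to a flat sum over all devices, while the pairwise work per secure aggregation instance stays bounded by the group size.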

Developer tools and workflows

Federated learning poses several challenges compared with traditional model training. First, since the server cannot see the training data on each node, dedicated tools are needed to pre-train and simulate on proxy data during modeling and initialization. Second, federated learning models cannot be run interactively; they must be compiled in advance and deployed to the server. Finally, the resource consumption and scalability of the model must be tested beforehand. For these reasons, the authors designed a set of Python tools and workflows to help developers with these problems.

In the model design and simulation phase, engineers use the tools and libraries to define federated learning tasks, build and pre-train models, and simulate the whole training process. The resulting parameters can serve as the initial parameters for production training.

In the plan generation phase, each federated learning task is associated with a federated learning plan, which is generated from the combination of a model and its configuration. Each plan has two parts, one for the server and one for the device. The authors designed a library that splits the plan into these two parts automatically.

In the deployment phase, a federated learning plan can be deployed to the server only if certain conditions are met. Versioning is also a challenge in federated learning, because devices may run different TensorFlow versions. To handle this, the tooling generates versioned plans, mainly by transforming the computation graph so that the plan is compatible with other TensorFlow versions.

Finally, in the evaluation phase, each node records additional information to assist analysis, as described earlier, and engineers can examine this data with the provided analysis tools.

Communication efficiency

Because of the constraints on each device, communication efficiency can become a bottleneck for federated learning. There are two main ways to improve communication efficiency, both of which work by changing the update that is transmitted. The first is to transmit a structured update, which is learned directly in a restricted, compact form. The second is to transmit a sketched update, which is a full update that is compressed before upload.
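To make the second approach concrete, here is a minimal sketch of two ways a device could compress a full update before uploading it: random subsampling and 1-bit probabilistic quantization. The notes do not say which compression schemes are used, so both functions and their names are assumptions chosen for illustration. Each one is unbiased, meaning the compressed value equals the original in expectation.

```python
import random


def subsample(update, keep_prob, rng):
    """Keep each coordinate with probability keep_prob and rescale the kept
    values by 1/keep_prob so the expected value of every coordinate is unchanged."""
    return [u / keep_prob if rng.random() < keep_prob else 0.0 for u in update]


def quantize_1bit(update, rng):
    """Round each coordinate to the vector's min or max, choosing max with
    probability proportional to how close the value is to max."""
    lo, hi = min(update), max(update)
    if hi == lo:
        return list(update)
    return [hi if rng.random() < (u - lo) / (hi - lo) else lo for u in update]


rng = random.Random(0)
update = [0.0, 0.25, 0.5, 1.0]
print(subsample(update, 0.5, rng))
print(quantize_1bit(update, rng))
```

A single compressed update is noisy, but because the server averages updates from many devices, the per-device noise tends to cancel out.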


In this project, the Google researchers describe the main components of the system and the challenges they faced, identify the issues that remain unresolved, and hope that this work will be instructive for further systems research.
