摘要: 本文介绍了联邦学习的概念与各种算法以及具体实现。

FedAvg算法

联邦学习原始论文中给出的FedAvg的算法框架为:

在这里插入图片描述

参数介绍:\(K\)表示客户端的个数,\(B\) 表示Batch_size\(E\)表示本地更新的次数,\(\eta\)表示学习率。 \(P_k\)表示该客户端所有的数据。另外一个B表示的是由切分后的数据Batch组成的集合,\(C\)是比例,当\(C==1\)算法变为FedSGD

  1. \(n\)个客户端选出随机选出\(m\)个客户端。

  2. \(S_t\)为客户端的随机序列。

  3. \(S_t\)中每个客户端并行进行梯度下降。

  4. 上传模型,中心服务器进行平均。

  5. 注意,我们发现

    image-20221202081253192

全局模型的生成这一步对K个模型都进行了求和,但是实际上只抽样了m个客户端,我认为这种写法还是错误的。参考了书本和很多博客,这一步想要表达的意思就是说,随机选择m个客户端采样,对这m个客户端的梯度更新进行平均以形成全局更新,同时用当前全局模型替换未采样的客户端。\(n_k\)为k的样本数量,\(n\)是样本总量,假设样本均衡,那就是平均。

为什么是上传模型而不是梯度?

FedAvg解决的FL主要瓶颈 :
1. 通信速率不稳定,且可能不可靠
2. 聚合服务器的容量有限,同时与server通信的client的数量受限

因此将模型在本地多训练几个Epoch再上传,一直上传梯度对于中心服务器这种高频度的通信可能是无法承受的。

Fl Minst代码实现

  • 初始化
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import random
    net = Net().to(device)
    cilent_num=15
    client_nets=[Net().to(device) for i in range(cilent_num)] # 建立客户端网络
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    optimizers = [optim.SGD(client_nets[index].parameters(), lr=0.01, momentum=0.9) for index in range(cilent_num)]# 为每个客户端建立单独的优化器
    # 下载模型参数到客户端
    def download():
    for model in client_nets:
    model.load_state_dict(net.state_dict())
    download()
  • 联邦学习
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38

    client_index=[]
    for i in range(cilent_num):
    if random.randint(0,100)<50:
    client_index.append(i) # 随机选择客户端

    for epoch in range(num_epochs):
    index=0
    for i, data in enumerate(train_loader):
    if index in client_index:
    images, labels = data
    images = images.to(device)
    labels = labels.to(device)
    outputs = client_nets[index](images)
    loss = criterion(outputs, labels)
    loss.backward() # 求梯度,梯度值存在net中。
    optimizers[index].step()
    optimizers[index].zero_grad()
    index+=1
    if index%15==0:
    index=0
    ## FedAvg
    first=True
    weight=None
    import copy
    for index in client_index:
    if first:
    weight=copy.deepcopy(client_nets[index]).state_dict()
    for key,value in weight.items():
    weight[key]=weight[key]/len(client_index)
    first=False
    else:
    new_weight=copy.deepcopy(client_nets[index].state_dict())
    for key,value in weight.items():
    weight[key]=weight[key]+new_weight[key]/len(client_index)#先除以总项目数再加
    net.load_state_dict(weight)
    # 更新客户端
    download()
  • 两种数据毒害手段
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    def ShuffleLabel(t):  
    t_list=[int(i) for i in t]
    random.shuffle(t_list)
    res=torch.LongTensor(t_list)
    return res
    # Attack 10%
    def RandomLabel(t):
    t_list=[random.randint(0,0) for i in t]
    random.shuffle(t_list)
    res=torch.LongTensor(t_list)
    return res
    # Attack 15%
1
2
3
adverse_index=[2,6,8] #建立一个恶意用户名单。
if index in adverse_index:
labels=ShuffleLabel(labels)

联邦学习防御系统

  • 基于可信度排名的中心模型聚合机制
  • 基于模型特征聚类与动态验证的可信度计算机制
  • 基于扰动的在线防御机制