联邦学习
摘要: 本文介绍了联邦学习的概念与各种算法以及具体实现。
FedAvg算法
联邦学习原始论文中给出的FedAvg的算法框架为:
参数介绍:\(K\)表示客户端的个数,\(B\) 表示Batch_size,\(E\)表示本地更新的次数,\(\eta\)表示学习率。 \(P_k\)表示该客户端所有的数据。另外一个B表示的是由切分后的数据Batch组成的集合,\(C\)是比例,当\(C==1\)算法变为FedSGD。
从\(n\)个客户端选出随机选出\(m\)个客户端。
\(S_t\)为客户端的随机序列。
\(S_t\)中每个客户端并行进行梯度下降。
上传模型,中心服务器进行平均。
注意,我们发现
全局模型的生成这一步对K个模型都进行了求和,但是实际上只抽样了m个客户端,我认为这种写法还是错误的。参考了书本和很多博客,这一步想要表达的意思就是说,随机选择m个客户端采样,对这m个客户端的梯度更新进行平均以形成全局更新,同时用当前全局模型替换未采样的客户端。\(n_k\)为k的样本数量,\(n\)是样本总量,假设样本均衡,那就是平均。
为什么是上传模型而不是梯度?
FedAvg解决的FL主要瓶颈 :
1. 通信速率不稳定,且可能不可靠
2. 聚合服务器的容量有限,同时与server通信的client的数量受限
因此将模型在本地多训练几个Epoch再上传,一直上传梯度对于中心服务器这种高频度的通信可能是无法承受的。
Fl Minst代码实现
- 初始化
1
2
3
4
5
6
7
8
9
10
11
12import random
net = Net().to(device)
cilent_num=15
client_nets=[Net().to(device) for i in range(cilent_num)] # 建立客户端网络
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
optimizers = [optim.SGD(client_nets[index].parameters(), lr=0.01, momentum=0.9) for index in range(cilent_num)]# 为每个客户端建立单独的优化器
# 下载模型参数到客户端
def download():
for model in client_nets:
model.load_state_dict(net.state_dict())
download() - 联邦学习
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
client_index=[]
for i in range(cilent_num):
if random.randint(0,100)<50:
client_index.append(i) # 随机选择客户端
for epoch in range(num_epochs):
index=0
for i, data in enumerate(train_loader):
if index in client_index:
images, labels = data
images = images.to(device)
labels = labels.to(device)
outputs = client_nets[index](images)
loss = criterion(outputs, labels)
loss.backward() # 求梯度,梯度值存在net中。
optimizers[index].step()
optimizers[index].zero_grad()
index+=1
if index%15==0:
index=0
## FedAvg
first=True
weight=None
import copy
for index in client_index:
if first:
weight=copy.deepcopy(client_nets[index]).state_dict()
for key,value in weight.items():
weight[key]=weight[key]/len(client_index)
first=False
else:
new_weight=copy.deepcopy(client_nets[index].state_dict())
for key,value in weight.items():
weight[key]=weight[key]+new_weight[key]/len(client_index)#先除以总项目数再加
net.load_state_dict(weight)
# 更新客户端
download() - 两种数据毒害手段
1
2
3
4
5
6
7
8
9
10
11
12def ShuffleLabel(t):
t_list=[int(i) for i in t]
random.shuffle(t_list)
res=torch.LongTensor(t_list)
return res
# Attack 10%
def RandomLabel(t):
t_list=[random.randint(0,0) for i in t]
random.shuffle(t_list)
res=torch.LongTensor(t_list)
return res
# Attack 15%
1 | adverse_index=[2,6,8] #建立一个恶意用户名单。 |
联邦学习防御系统
- 基于可信度排名的中心模型聚合机制
- 基于模型特征聚类与动态验证的可信度计算机制
- 基于扰动的在线防御机制
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.