本發(fā)明屬于聯(lián)邦學(xué)習(xí)的,更具體地,涉及基于擴(kuò)散模型的自適應(yīng)雙邊蒸餾個性化聯(lián)邦學(xué)習(xí)方法。
背景技術(shù):
1、聯(lián)邦學(xué)習(xí)是一種分布式機(jī)器學(xué)習(xí)方法,其最初的目的是為了解決“數(shù)據(jù)孤島”問題。然而,隨著技術(shù)的不斷迭代和用戶數(shù)量的急劇增加,一個單一的泛化全局模型已經(jīng)難以滿足用戶對個性化的需求。為了解決這一問題,個性化聯(lián)邦學(xué)習(xí)應(yīng)運(yùn)而生。個性化聯(lián)邦學(xué)習(xí)在共享全局知識的同時,針對每個用戶的特定數(shù)據(jù)分布和任務(wù)需求優(yōu)化本地模型的性能。相比傳統(tǒng)聯(lián)邦學(xué)習(xí),它不僅能夠緩解數(shù)據(jù)異構(gòu)問題,還能在犧牲極少全局泛化能力的前提下,顯著提升本地模型在特定客戶端上的表現(xiàn)。
2、作為一種高效的知識遷移手段,知識蒸餾被廣泛應(yīng)用于個性化聯(lián)邦學(xué)習(xí)中。然而,由于隱私限制,服務(wù)器無法直接訪問客戶端的本地數(shù)據(jù),這使得生成偽數(shù)據(jù)成為一種有效的解決方案。當(dāng)前較為流行的偽數(shù)據(jù)生成技術(shù)是生成對抗網(wǎng)絡(luò)gan。然而,gan需要生成器generator和判別器discriminator之間的對抗訓(xùn)練,在此過程中可能出現(xiàn)“模式崩潰”現(xiàn)象,即生成的樣本缺乏多樣性,僅覆蓋少量相似的模式。為克服這一問題,條件擴(kuò)散模型conditional?diffusion?models被提出。該方法通過逐步向原始數(shù)據(jù)添加噪聲并學(xué)習(xí)去噪過程生成偽數(shù)據(jù),從根本上解決了gan的模式崩潰問題,生成的樣本質(zhì)量和多樣性得到了顯著提升。
3、現(xiàn)有的大多數(shù)個性化聯(lián)邦學(xué)習(xí)方法往往只關(guān)注模型個性化和泛化能力中的某一個方面,導(dǎo)致兩者失衡,甚至可能遺忘先前學(xué)到的知識。例如,feddistill通過客戶端和服務(wù)器之間傳遞logits實(shí)現(xiàn)信息交換,這雖然提升了全局模型的泛化性能,但忽視了對客戶端個性化需求的支持,且logits的傳遞可能帶來隱私泄露風(fēng)險。fedamp則通過計算客戶端間的模型相似性,實(shí)現(xiàn)自適應(yīng)的模型協(xié)作聚合。然而,它的重點(diǎn)依然放在全局模型的優(yōu)化上,缺乏對客戶端個性化的有效支持。fml通過本地模型和全局模型的相互蒸餾,兼顧了個性化知識和全局知識的學(xué)習(xí),但全局模型在利用本地數(shù)據(jù)訓(xùn)練時可能遺忘部分全局知識,導(dǎo)致泛化性能下降。fedproto假設(shè)客戶端之間存在共享的類原型,但在客戶端類分布完全不同或無重疊的情況下,全局原型可能無法準(zhǔn)確表達(dá)全局特征,導(dǎo)致其泛化能力受限。
4、為了解決上述問題,近年來提出了一些改進(jìn)方案。例如,fedkd通過本地模型與全局模型的自適應(yīng)相互蒸餾,實(shí)現(xiàn)了個性化與泛化性能的平衡。然而,它仍可能遺忘部分全局知識,從而影響整體效果。因此,迫切需要一種方法,在盡可能保證模型泛化性能和避免遺忘知識的前提下,實(shí)現(xiàn)更高效的個性化。
5、中國專利文獻(xiàn)cn117035057a公開一種基于模型和數(shù)據(jù)蒸餾的個性化聯(lián)邦學(xué)習(xí)方法,該方法步驟如下:在客戶端構(gòu)建本地模型,包含共享編碼器和私有解碼器??蛻舳烁鶕?jù)私有數(shù)據(jù)集訓(xùn)練本地模型,向服務(wù)器上傳共享編碼器的模型參數(shù)。客戶端根據(jù)本地模型計算公有數(shù)據(jù)集的輸出logits并向服務(wù)器上傳logits。服務(wù)器對logits和多個客戶端的共享編碼器模型參數(shù)分別做加權(quán)聚合,得到全局logits和多個全局編碼器模型參數(shù)。各客戶端下載多個全局編碼器模型,更新客戶端中的多個本地共享編碼器模型參數(shù),下載全局logits,以知識蒸餾的方式參與解碼器的訓(xùn)練。但是該方法存在以下不足:1、該方法依賴全局編碼器和全局logits進(jìn)行知識蒸餾,但全局模型在捕捉客戶端個性化特征方面存在局限性,特別是在數(shù)據(jù)分布非獨(dú)立同分布的情況下,難以滿足個性化需求。2、該方法未充分考慮全局模型在本地更新過程中可能出現(xiàn)的“災(zāi)難性遺忘”問題,即對先前知識的丟失。
6、有鑒于此,本發(fā)明設(shè)計一種基于擴(kuò)散模型的自適應(yīng)雙邊蒸餾個性化聯(lián)邦學(xué)習(xí)方法。
技術(shù)實(shí)現(xiàn)思路
1、本發(fā)明旨在克服上述現(xiàn)有技術(shù)的至少一種缺陷,提供基于擴(kuò)散模型的自適應(yīng)雙邊蒸餾個性化聯(lián)邦學(xué)習(xí)方法,通過設(shè)計一種指導(dǎo)機(jī)制,使全局模型與本地模型之間進(jìn)行相互蒸餾,從而實(shí)現(xiàn)知識的高效傳遞,增強(qiáng)個性化模型對客戶端特定數(shù)據(jù)分布的適應(yīng)性。同時,引入條件擴(kuò)散模型生成高質(zhì)量的偽數(shù)據(jù),并利用這些偽數(shù)據(jù)對聚合后的全局模型進(jìn)行微調(diào)。該過程不僅有效彌補(bǔ)了局部-全局相互蒸餾過程中可能丟失的全局信息,還進(jìn)一步優(yōu)化了全局模型的表現(xiàn)。通過結(jié)合相互蒸餾和條件擴(kuò)散微調(diào)技術(shù),本發(fā)明在保護(hù)數(shù)據(jù)隱私的同時,實(shí)現(xiàn)了個性化性能與全局泛化能力的平衡,適用于非獨(dú)立同分布non-iid數(shù)據(jù)環(huán)境下的多客戶端協(xié)作場景。
2、本發(fā)明詳細(xì)的技術(shù)方案如下:
3、基于擴(kuò)散模型的自適應(yīng)雙邊蒸餾個性化聯(lián)邦學(xué)習(xí)方法,所述方法包括:
4、s1、服務(wù)器初始化全局模型,將初始的全局模型廣播發(fā)送給各個參與的客戶端;
5、s2、客戶端接收服務(wù)器傳來的全局模型,客戶端利用本地數(shù)據(jù)對接收到的全局模型和本地的局部模型進(jìn)行訓(xùn)練,得到局部損失和全局損失,然后,利用本地數(shù)據(jù)對全局模型和本地的局部模型進(jìn)行相互蒸餾,并利用局部損失和全局損失來對蒸餾過程進(jìn)行指導(dǎo),得到本地全局模型;同時,客戶端利用本地數(shù)據(jù)和類別向量來訓(xùn)練本地條件擴(kuò)散模型得到本地的局部生成器;
6、s3、客戶端將調(diào)整后的本地全局模型和本地的局部生成器發(fā)生至服務(wù)器;
7、s4、服務(wù)器根據(jù)客戶端的本地數(shù)據(jù)量將收到的本地全局模型與初始全局模型進(jìn)行聚合,得到聚合后的全局模型;然后利用kl散度來計算本地全局模型和聚合后的全局模型的相似度;
8、s5、利用相似度來對接收到的本地的局部生成器進(jìn)行加權(quán)聚合,得到一個能夠產(chǎn)生全局偽數(shù)據(jù)的全局生成器;
9、s6、服務(wù)器利用全局生成器生成的全局偽數(shù)據(jù)對聚合后的全局模型和歷史全局模型進(jìn)行知識蒸餾,進(jìn)一步優(yōu)化全局模型;
10、s7、服務(wù)器將優(yōu)化后的全局模型重新廣播發(fā)送給各個參與的客戶端,重復(fù)上述步驟s1-s6,直到達(dá)到預(yù)設(shè)的輪次后結(jié)束,得到最終微調(diào)后的全局模型。
11、根據(jù)本發(fā)明優(yōu)選的,步驟s2中,所述客戶端利用本地數(shù)據(jù)對接收到的全局模型和本地的局部模型進(jìn)行訓(xùn)練,得到局部損失和全局損失具體如下:
12、將本地數(shù)據(jù)分別送入到本地的局部模型和全局模型進(jìn)行訓(xùn)練,在過程使用交叉熵?fù)p失函數(shù)來計算它們各自的損失:
13、(1)
14、(2)
15、式(1)和(2)中,表示局部損失,表示全局損失,是局部模型預(yù)測的結(jié)果,是全局模型預(yù)測的結(jié)果,是真實(shí)的標(biāo)簽,代表客戶端數(shù)據(jù)集的類別總數(shù)。
16、根據(jù)本發(fā)明優(yōu)選的,步驟s2中,所述利用本地數(shù)據(jù)對全局模型和本地的局部模型進(jìn)行相互蒸餾,并利用局部損失和全局損失來對蒸餾過程進(jìn)行指導(dǎo),得到本地全局模型具體如下:
17、根據(jù)本地的局部模型和全局模型在數(shù)據(jù)集上預(yù)測的準(zhǔn)確性來控制相互蒸餾的強(qiáng)度,若準(zhǔn)確性越高,則損失越小,這時蒸餾強(qiáng)度就越小,若準(zhǔn)確性低,則損失越大,這時蒸餾強(qiáng)度就越大,需要學(xué)習(xí)更多的知識;
18、在相互蒸餾的過程中,使用了kl散度來控制本地的局部模型預(yù)測分布和全局模型預(yù)測分布之間的差異,通過最小化二者的差異來達(dá)到知識轉(zhuǎn)移的目的,這個過程如下所示:
19、(3)
20、(4)
21、式(3)和(4)中,表示本地模型相對于全局模型的差異,表示全局模型相對于本地模型的差異,是局部模型預(yù)測的結(jié)果,是全局模型預(yù)測的結(jié)果,表示局部損失,表示全局損失,分母通過兩個交叉熵?fù)p失的和實(shí)現(xiàn)了控制知識轉(zhuǎn)移強(qiáng)度的目的。
22、根據(jù)本發(fā)明優(yōu)選的,步驟s2中,所述客戶端利用本地數(shù)據(jù)和類別向量來訓(xùn)練本地條件擴(kuò)散模型得到本地的局部生成器具體如下:
23、首先,利用條件擴(kuò)散模型,進(jìn)行前向擴(kuò)散:將本地數(shù)據(jù)作為目標(biāo)數(shù)據(jù),通過逐漸向原始數(shù)據(jù)中添加噪聲,使得原始數(shù)據(jù)逐漸接近于純噪聲,公式如下:
24、(5)
25、式(5)中,表示在擴(kuò)散模型中從原始數(shù)據(jù)到的條件概率分布,代表原始數(shù)據(jù),代表對原始數(shù)據(jù)經(jīng)過加噪時間步t之后的數(shù)據(jù),表示高斯分布,用于控制噪聲的大小,表示累積噪聲衰減系數(shù),i表示單位矩陣;通過該過程,原始數(shù)據(jù)逐步變?yōu)榧冊肼暎?/p>
26、然后,進(jìn)行逆向擴(kuò)散,在逆向擴(kuò)散過程中,引入了類別嵌入向量c,用于控制生成符合本地數(shù)據(jù)集分布的數(shù)據(jù);隨后,再利用前向擴(kuò)散過程中學(xué)到的加入的噪聲的方差和均值來逐步去噪達(dá)到還原原始數(shù)據(jù)的目的,公式表示如下:
27、(6)
28、(7)
29、在式(6)和(7)中,表示從逆向生成時間t-1狀態(tài)樣本的條件概率分布,c是類別嵌入向量,是噪聲的均值,為噪聲的方差,t代表時間步,代表從條件概率分布中采樣得到時間狀態(tài)t-1時刻的樣本;經(jīng)過逆向擴(kuò)散過程逐步將數(shù)據(jù)還原為符合目標(biāo)分布的目標(biāo)數(shù)據(jù);
30、逆向優(yōu)化過程的優(yōu)化目標(biāo)為:
31、(8)
32、式(8)中,表示均方誤差損失函數(shù),是標(biāo)準(zhǔn)的高斯噪聲,是模型預(yù)測的噪聲,表示對時間步t、原始數(shù)據(jù)和噪聲的期望值;條件擴(kuò)散模型通過最小化模型預(yù)測的噪聲和真實(shí)添加的噪聲之間的距離,從而使逆向生成過程更加準(zhǔn)確;通過該優(yōu)化過程條件擴(kuò)散模型能夠逐步學(xué)習(xí)逆向擴(kuò)散過程;
33、至此,局部生成器訓(xùn)練過程結(jié)束,得到了用于產(chǎn)生符合客戶端數(shù)據(jù)分布的局部生成器。
34、根據(jù)本發(fā)明優(yōu)選的,步驟s4具體如下:
35、服務(wù)器對全局模型根據(jù)其本地的數(shù)據(jù)量進(jìn)行全局聚合,公式如下:
36、(9)
37、在式(9)中,是第輪聚合后的全局模型,n代表參與聚合的客戶端總數(shù),代表客戶端j的本地數(shù)據(jù)量,|d|代表所有客戶端的數(shù)據(jù)總量,代表第i輪的客戶端j的全局模型;
38、聚合完成后,服務(wù)器需要計算各個本地的全局模型和聚合后的全局模型之間的相似度,由于這里的模型參數(shù)都是高維的向量,因此需要使用余弦相似度,其計算過程如下所示:
39、(10)
40、在式(10)中,是計算本地的全局模型和聚合后全局模型的初步相似度,代表第i輪的客戶端k的全局模型參數(shù)向量,代表第i輪聚合后的全局模型參數(shù)向量,代表l2范數(shù),和表示將和歸一化為單位向量;隨后將進(jìn)行標(biāo)準(zhǔn)化處理,如公式(11)所示:
41、(11)
42、在式(11)中,是第輪客戶端的本地全局模型和聚合后的全局模型的相似度,滿足。
43、根據(jù)本發(fā)明優(yōu)選的,步驟s5具體如下:
44、根據(jù)計算得到的相似度,對各個局部生成器進(jìn)行加權(quán)聚合,其計算公式如(12)所示:
45、(12)
46、在式(12)中,代表第i輪聚合后的全局生成器,代表第輪客戶端的本地全局模型和聚合后的全局模型的相似度,代表第i輪客戶端k的本地生成器。
47、根據(jù)本發(fā)明優(yōu)選的,由于聚合后的全局模型經(jīng)過本地的訓(xùn)練后可能會遺忘之前已經(jīng)學(xué)過的知識,因此,為了應(yīng)對這種情況,考慮利用全局生成器生成的偽數(shù)據(jù)和前輪保存的全局模型對聚合后的全局模型進(jìn)行知識蒸餾微調(diào),這就從根本上避免了全局模型的“災(zāi)難性遺忘”的問題。
48、首先利用全局生成器生成全局偽數(shù)據(jù),然后選擇前輪的全局模型進(jìn)行加權(quán)聚合,得到聚合后的前輪的全局模型參數(shù),具體公式如(13)所示:
49、(13)
50、在式(13)中,代表前m輪全局模型參數(shù)的平均值,代表第i-j輪的全局模型參數(shù)向量,m表示選擇前m輪的全局模型參數(shù),為了避免不必要的計算成本,這里的,其中,代表客戶端和服務(wù)器的通信輪數(shù);
51、隨后,再根據(jù)得到的和對進(jìn)行知識蒸餾:首先,計算蒸餾過程所使用的學(xué)生和教師模型的概率分布:
52、(14)
53、(15)
54、式(14)和(15)中,是經(jīng)過平滑處理后的教師模型的概率分布,是經(jīng)過平滑處理后的學(xué)生模型的概率分布,代表教師模型在全局偽數(shù)據(jù)上的輸出logits,代表學(xué)生模型在全局偽數(shù)據(jù)上的輸出logits,代表蒸餾溫度;
55、接下來,計算學(xué)生模型的概率分布和教師模型的概率分布之間的kl散度:
56、(16)
57、式(16)中,用于抵消概率分布平滑帶來的梯度縮放;
58、通過最小化kl散度,來更新,從而達(dá)到微調(diào)的效果,其公式如(17)所示:
59、(17)
60、在式(17)中,表示學(xué)習(xí)率,用于控制每次更新的步長大小,表示損失函數(shù)的梯度;
61、最后,將經(jīng)過微調(diào)后的全局模型參數(shù)廣播分發(fā)給各個客戶端,用其更新本地的全局模型參數(shù),不斷地迭代以上過程,直到全局模型收斂。
62、與現(xiàn)有技術(shù)相比,本發(fā)明的有益效果為:
63、(1)本發(fā)明通過局部模型與全局模型的自適應(yīng)相互蒸餾,充分利用局部損失和全局損失,引導(dǎo)全局模型在泛化性能和個性化需求之間實(shí)現(xiàn)動態(tài)平衡。
64、(2)本發(fā)明利用全局生成器生成的偽數(shù)據(jù)對聚合后的全局模型進(jìn)行微調(diào),有效補(bǔ)償在蒸餾過程中可能丟失的全局知識,避免全局模型因新知識引入而遺忘歷史知識。
65、(3)由于各個客戶端的數(shù)據(jù)量、數(shù)據(jù)分布等都會存在差異,本發(fā)明通過為各個客戶端學(xué)習(xí)個性化的模型從而避免了數(shù)據(jù)異構(gòu)帶來的影響。針對聯(lián)邦學(xué)習(xí)中常見的非獨(dú)立同分布數(shù)據(jù),本發(fā)明的相互蒸餾機(jī)制和條件擴(kuò)散模型的引入能夠更好地適應(yīng)和處理各客戶端數(shù)據(jù)的異構(gòu)性。
66、(4)本發(fā)明使用全局偽數(shù)據(jù)進(jìn)行全局知識蒸餾,避免了直接使用私有數(shù)據(jù)帶來的隱私泄露的風(fēng)險。