本發(fā)明涉及人工智能,具體地,涉及一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法、系統(tǒng)及介質(zhì)。
背景技術(shù):
1、聯(lián)邦學(xué)習(xí)作為一種從分散異構(gòu)數(shù)據(jù)中進(jìn)行學(xué)習(xí)的方法,目前已廣受關(guān)注。然而,聯(lián)邦學(xué)習(xí)面臨多重挑戰(zhàn),其中主要挑戰(zhàn)有兩點(diǎn):(1)客戶端數(shù)據(jù)分布呈現(xiàn)非獨(dú)立同分布(non-iid)特性;(2)由于獲取準(zhǔn)確標(biāo)注數(shù)據(jù)所需成本較高,客戶端數(shù)據(jù)的標(biāo)簽具有稀缺性。這些因素可能導(dǎo)致聯(lián)邦學(xué)習(xí)性能下降及訓(xùn)練收斂遭遇困難。現(xiàn)有的研究雖通過多種方法來應(yīng)對(duì)這些挑戰(zhàn),如緩解客戶端漂移、優(yōu)化服務(wù)器端模型聚合、利用輔助公共數(shù)據(jù)集進(jìn)行數(shù)據(jù)增強(qiáng)、以及開展聯(lián)邦半監(jiān)督學(xué)習(xí)等。盡管這些方法取得了一定成效,但它們并未從根本上解決聯(lián)邦學(xué)習(xí)性能下降的問題,即分散客戶端的有標(biāo)簽數(shù)據(jù)分布與全局有標(biāo)簽數(shù)據(jù)分布之間存在不匹配。
2、以擴(kuò)散模型為代表的生成模型展現(xiàn)了生成高質(zhì)量、可控?cái)?shù)據(jù)樣本的能力,為緩解客戶端數(shù)據(jù)分布與全局?jǐn)?shù)據(jù)分布差異提供了新的可能性?,F(xiàn)有的研究大多采用預(yù)訓(xùn)練生成模型來生成數(shù)據(jù)樣本,以對(duì)抗數(shù)據(jù)分布non-iid的影響。然而,這種方法存在兩個(gè)主要限制:首先,訓(xùn)練適用于特定場景的預(yù)訓(xùn)練生成模型需要耗費(fèi)大量資源;其次,這種方法未能充分利用客戶端上的未標(biāo)注數(shù)據(jù)。
3、現(xiàn)有技術(shù)文獻(xiàn)cn116229172a公開了一種基于對(duì)比學(xué)習(xí)的聯(lián)邦少樣本圖像分類模型訓(xùn)練方法、分類方法及設(shè)備,該文獻(xiàn)包括:構(gòu)建訓(xùn)練集和查詢集,為訓(xùn)練集添加真實(shí)標(biāo)簽;獲取初始模型,該初始模型包括嵌入網(wǎng)絡(luò)和關(guān)系網(wǎng)絡(luò);將訓(xùn)練集樣本和查詢集樣本成對(duì)輸入嵌入網(wǎng)絡(luò),提取訓(xùn)練集樣本特征圖和查詢集樣本特征圖并進(jìn)行拼接,生成拼接特征圖;將拼接特征圖輸入關(guān)系網(wǎng)絡(luò)計(jì)算得到相似度分?jǐn)?shù),以得到訓(xùn)練集樣本的類別;采用本地?cái)?shù)據(jù)集對(duì)初始模型進(jìn)行訓(xùn)練,并構(gòu)建均方誤差損失,得到初始圖像分類模型;基于各客戶端模型參數(shù)構(gòu)建共享模型,并根據(jù)共享模型參數(shù),采用指數(shù)移動(dòng)平均更新初始圖像分類模型,得到最終的圖像分類模型。該現(xiàn)有技術(shù)文獻(xiàn)采用對(duì)比學(xué)習(xí),在處理非獨(dú)立同分布(non-iid)數(shù)據(jù)上存在不足。
4、因此,如何更有效地利用分散客戶端上的數(shù)據(jù)資源,同時(shí)克服數(shù)據(jù)non-iid和標(biāo)簽稀缺帶來的影響,仍然是一個(gè)具有挑戰(zhàn)性的問題,急需研究一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)系統(tǒng)及方法。
技術(shù)實(shí)現(xiàn)思路
1、針對(duì)現(xiàn)有技術(shù)中的缺陷,本發(fā)明旨在提供一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)輔助少樣本聯(lián)邦學(xué)習(xí)方法、系統(tǒng)及介質(zhì),以提升聯(lián)邦學(xué)習(xí)在復(fù)雜數(shù)據(jù)環(huán)境下的性能。
2、根據(jù)本發(fā)明提供的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法,包括如下步驟:
3、步驟s1,邊緣服務(wù)器利用分散在每個(gè)客戶端上標(biāo)注數(shù)據(jù),進(jìn)行聯(lián)邦訓(xùn)練,構(gòu)建基礎(chǔ)分類器;
4、步驟s2,每個(gè)客戶端采用所述基礎(chǔ)分類器對(duì)未標(biāo)注數(shù)據(jù)進(jìn)行預(yù)測,得到每個(gè)未標(biāo)注數(shù)據(jù)的類別概率分布,根據(jù)預(yù)測的概率分布,為未標(biāo)注數(shù)據(jù)進(jìn)行偽標(biāo)簽分配,得到帶有偽標(biāo)簽的數(shù)據(jù);
5、步驟s3,將標(biāo)注數(shù)據(jù)和帶有偽標(biāo)簽的數(shù)據(jù)合并,進(jìn)行生成模型的聯(lián)邦訓(xùn)練,得到訓(xùn)練好的生成模型;
6、步驟s4,客戶端向邊緣服務(wù)器上傳本地標(biāo)注數(shù)據(jù)分布信息,邊緣服務(wù)器綜合所有客戶端的分布信息,得到全局標(biāo)注數(shù)據(jù)分布,并反饋給每個(gè)客戶端;
7、步驟s5,每個(gè)客戶端根據(jù)全局標(biāo)注數(shù)據(jù)分布,利用訓(xùn)練好的生成模型為每個(gè)類別生成合成數(shù)據(jù);
8、步驟s6,將標(biāo)注數(shù)據(jù)和合成數(shù)據(jù)合并,聯(lián)邦訓(xùn)練最終分類器。
9、優(yōu)選地,步驟s1中,每個(gè)客戶端在本地對(duì)基本分類器進(jìn)行訓(xùn)練,在每個(gè)輪次中,客戶端將基本分類器參數(shù)上傳至邊緣服務(wù)器,邊緣服務(wù)器采用特定的參數(shù)聚合算法得到全局模型參數(shù),然后邊緣服務(wù)器將全局模型參數(shù)下發(fā)至每個(gè)客戶端,客戶端接收全局模型參數(shù)后繼續(xù)本地訓(xùn)練,重復(fù)上述過程直到收斂。
10、優(yōu)選地,步驟s1中,基礎(chǔ)分類器采用殘差網(wǎng)絡(luò)。
11、優(yōu)選地,邊緣服務(wù)器采用聯(lián)邦平均算法得到全局模型參數(shù)。
12、優(yōu)選地,步驟s3中,生成模型采用條件擴(kuò)散模型。
13、本發(fā)明還提供了一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)系統(tǒng),包括:構(gòu)建基礎(chǔ)分類器模塊,邊緣服務(wù)器利用分散在每個(gè)客戶端上標(biāo)注數(shù)據(jù),進(jìn)行聯(lián)邦訓(xùn)練,構(gòu)建基礎(chǔ)分類器;
14、偽標(biāo)簽生成模塊,每個(gè)客戶端運(yùn)用基礎(chǔ)分類器對(duì)未標(biāo)注數(shù)據(jù)進(jìn)行偽標(biāo)簽分配,得到帶有偽標(biāo)簽的數(shù)據(jù);
15、生成模型訓(xùn)練模塊,將標(biāo)注數(shù)據(jù)和帶有偽標(biāo)簽的數(shù)據(jù)合并,進(jìn)行生成模型的聯(lián)邦訓(xùn)練,得到訓(xùn)練好的生成模型;
16、分布信息交換模塊,客戶端向邊緣服務(wù)器上傳本地標(biāo)注數(shù)據(jù)分布信息,邊緣服務(wù)器綜合所有客戶端的分布信息,得到全局標(biāo)注數(shù)據(jù)分布,并反饋給每個(gè)客戶端;
17、合成數(shù)據(jù)生成模塊,每個(gè)客戶端根據(jù)全局標(biāo)注數(shù)據(jù)分布,利用訓(xùn)練好的生成模型為每個(gè)類別生成合成數(shù)據(jù);
18、最終分類器訓(xùn)練模塊,將標(biāo)注數(shù)據(jù)和合成數(shù)據(jù)合并,聯(lián)邦訓(xùn)練最終分類器。
19、優(yōu)選地,系統(tǒng)包括一個(gè)邊緣服務(wù)器和多個(gè)客戶端,每個(gè)客戶端均設(shè)有分類器模型和生成模型,且每個(gè)客戶端上包含有標(biāo)簽數(shù)據(jù)和無標(biāo)簽數(shù)據(jù)。
20、優(yōu)選地,分類器模型采用殘差網(wǎng)絡(luò)。
21、優(yōu)選地,生成模型采用條件擴(kuò)散模型,條件擴(kuò)散模型根據(jù)設(shè)定的標(biāo)簽控制生成的數(shù)據(jù)類別。
22、本發(fā)明還提供了一種存儲(chǔ)有計(jì)算機(jī)程序的計(jì)算機(jī)可讀存儲(chǔ)介質(zhì),計(jì)算機(jī)程序被處理器執(zhí)行時(shí)實(shí)現(xiàn)上述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法的步驟。
23、與現(xiàn)有技術(shù)相比,本發(fā)明具有如下的有益效果:
24、本發(fā)明基于機(jī)器學(xué)習(xí),充分利用分散客戶端上的數(shù)據(jù)資源,借助生成模型彌合本地?cái)?shù)據(jù)分布和全局分布之間的差異,即使在復(fù)雜的數(shù)據(jù)分布場景下,本方法仍具有較好的表現(xiàn),克服了少樣本和數(shù)據(jù)分布不平衡的負(fù)面影響,同時(shí)降低現(xiàn)有方法中由于數(shù)據(jù)標(biāo)注和預(yù)訓(xùn)練模型所帶來的資源成本。
1.一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法,其特征在于,包括如下步驟:
2.根據(jù)權(quán)利要求1所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法,其特征在于,所述步驟s1中,每個(gè)客戶端在本地對(duì)所述基本分類器進(jìn)行訓(xùn)練,在每個(gè)輪次中,客戶端將基本分類器參數(shù)上傳至邊緣服務(wù)器,所述邊緣服務(wù)器采用特定的參數(shù)聚合算法得到全局模型參數(shù),然后所述邊緣服務(wù)器將所述全局模型參數(shù)下發(fā)至每個(gè)客戶端,客戶端接收全局模型參數(shù)后繼續(xù)本地訓(xùn)練,重復(fù)上述過程直到收斂。
3.根據(jù)權(quán)利要求1所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法,其特征在于,所述步驟s1中,所述基礎(chǔ)分類器采用殘差網(wǎng)絡(luò)。
4.根據(jù)權(quán)利要求2所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法,其特征在于,所述邊緣服務(wù)器采用聯(lián)邦平均算法得到全局模型參數(shù)。
5.根據(jù)權(quán)利要求1所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法,其特征在于,所述步驟s3中,所述生成模型采用條件擴(kuò)散模型。
6.一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)系統(tǒng),其特征在于,包括:
7.根據(jù)權(quán)利要求6所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)系統(tǒng),其特征在于,所述系統(tǒng)包括一個(gè)邊緣服務(wù)器和多個(gè)客戶端,每個(gè)客戶端均設(shè)有分類器模型和生成模型,且每個(gè)客戶端上包含有標(biāo)簽數(shù)據(jù)和無標(biāo)簽數(shù)據(jù)。
8.根據(jù)權(quán)利要求7所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)系統(tǒng),其特征在于,所述分類器模型采用殘差網(wǎng)絡(luò)。
9.根據(jù)權(quán)利要求7所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)系統(tǒng),其特征在于,所述生成模型采用條件擴(kuò)散模型,所述條件擴(kuò)散模型根據(jù)設(shè)定的標(biāo)簽控制生成的數(shù)據(jù)類別。
10.一種存儲(chǔ)有計(jì)算機(jī)程序的計(jì)算機(jī)可讀存儲(chǔ)介質(zhì),其特征在于,所述計(jì)算機(jī)程序被處理器執(zhí)行時(shí)實(shí)現(xiàn)權(quán)利要求1至5中任一項(xiàng)所述的一種基于生成模型的輔助少樣本聯(lián)邦學(xué)習(xí)方法的步驟。