本發(fā)明屬于計(jì)算機(jī)視覺(jué)領(lǐng)域,特別涉及一種基于剪枝和知識(shí)蒸餾組合的detr目標(biāo)檢測(cè)模型壓縮方法。
背景技術(shù):
1、detr是2020年提出的一種新的目標(biāo)檢測(cè)范式,其屬于無(wú)錨框檢測(cè)方法中的端到端的檢測(cè)方法。detr的模型結(jié)構(gòu)非常簡(jiǎn)單,包含三個(gè)主要組件:骨干網(wǎng)絡(luò)、transformer編碼器-解碼器特征融合網(wǎng)絡(luò)、前饋網(wǎng)絡(luò),但是該網(wǎng)絡(luò)模型計(jì)算成本高,內(nèi)存密集,阻礙了它在低內(nèi)存資源的設(shè)備或具有嚴(yán)格延遲要求的應(yīng)用程序中的部署。因此,就產(chǎn)生了在不顯著降低模型性能的情況下在深度網(wǎng)絡(luò)中執(zhí)行模型壓縮和加速的樸素想法。一般來(lái)說(shuō),常見(jiàn)的技術(shù)主要是網(wǎng)絡(luò)剪枝和知識(shí)蒸餾。
2、網(wǎng)絡(luò)剪枝可以降低網(wǎng)絡(luò)復(fù)雜性和解決過(guò)擬合問(wèn)題,網(wǎng)絡(luò)剪枝以保持原始精度的方式減少神經(jīng)網(wǎng)絡(luò)所需的存儲(chǔ)空間。常用的壓縮框架是一個(gè)三階段的管道:網(wǎng)絡(luò)剪枝、參數(shù)量化和霍夫曼編碼,通常應(yīng)用一個(gè)壓縮框架或者同時(shí)應(yīng)用多個(gè)壓縮框架,在不影響神經(jīng)網(wǎng)絡(luò)準(zhǔn)確性的情況下,將神經(jīng)網(wǎng)絡(luò)的存儲(chǔ)需求減少35到49倍。但這種剪枝框架在帶有transformer結(jié)構(gòu)的模型中無(wú)法有效使用,因此需要一種針對(duì)transformer結(jié)構(gòu)的剪枝方法來(lái)減少其巨大的計(jì)算開(kāi)銷(xiāo),促進(jìn)模型在邊緣設(shè)備中的部署和應(yīng)用。
3、知識(shí)蒸餾的核心思想是通過(guò)一個(gè)預(yù)先訓(xùn)練好的大模型(教師模型)來(lái)指導(dǎo)一個(gè)小模型(學(xué)生模型)的訓(xùn)練,這個(gè)過(guò)程就像教師在教學(xué)生一樣,通過(guò)將大型模型的性能轉(zhuǎn)移到小型模型,從而使小型模型在資源受限的環(huán)境中(如移動(dòng)設(shè)備和嵌入式系統(tǒng))也能表現(xiàn)出良好的性能,故稱(chēng)為教師-學(xué)生框架。其中,教師模型通常是一個(gè)大型的深度神經(jīng)網(wǎng)絡(luò),它在大量的訓(xùn)練數(shù)據(jù)上進(jìn)行訓(xùn)練,以達(dá)到較高的性能。然后,使用這個(gè)教師模型來(lái)生成軟標(biāo)簽(即模型的預(yù)測(cè)概率分布),并用這些軟標(biāo)簽來(lái)訓(xùn)練一個(gè)小的學(xué)生模型,進(jìn)而幫助學(xué)生模型學(xué)習(xí)到教師模型的一些隱含知識(shí),從而提高學(xué)生模型的性能。但是傳統(tǒng)知識(shí)蒸餾框架主要是針對(duì)分類(lèi)問(wèn)題,會(huì)對(duì)圖像的整體特征進(jìn)行蒸餾,而目標(biāo)檢測(cè)需要平衡前景和背景的特征差異,需要新的蒸餾方法。
4、綜上所述,基于網(wǎng)絡(luò)剪枝的模型壓縮方法不能在帶有transformer結(jié)構(gòu)的模型中有效使用,基于知識(shí)蒸餾的模型壓縮方法應(yīng)用在圖像中會(huì)對(duì)圖像特征進(jìn)行蒸餾,如何將網(wǎng)絡(luò)剪枝和知識(shí)蒸餾應(yīng)用在用于處理圖像且?guī)в衪ransformer結(jié)構(gòu)的模型中是本領(lǐng)域技術(shù)人員亟需解決的問(wèn)題。
技術(shù)實(shí)現(xiàn)思路
1、為了使detr目標(biāo)檢測(cè)模型在無(wú)明顯精度下滑的前提下,顯著降低模型的計(jì)算量并大幅提升模型的推理速度,本發(fā)明提出一種基于剪枝和知識(shí)蒸餾組合的detr目標(biāo)檢測(cè)模型壓縮方法,對(duì)待壓縮的目標(biāo)檢測(cè)模型中的transformer結(jié)構(gòu)進(jìn)行剪枝和蒸餾,具體包括以下步驟:
2、計(jì)算transformer結(jié)構(gòu)每一層的注意力機(jī)制,利用注意力機(jī)制構(gòu)建掩碼;
3、利用掩碼對(duì)transformer結(jié)構(gòu)線性層和注意力層的參數(shù)進(jìn)行自適應(yīng)剪枝,得到原模型的第一壓縮模型;
4、將待壓縮的transformer結(jié)構(gòu)的編碼器作為學(xué)生網(wǎng)絡(luò),原模型作為教師網(wǎng)絡(luò)進(jìn)行蒸餾學(xué)習(xí);
5、學(xué)生網(wǎng)絡(luò)根據(jù)全局蒸餾損失和交點(diǎn)蒸餾損失進(jìn)行蒸餾更新,完成壓縮。
6、優(yōu)選地,進(jìn)行自適應(yīng)剪枝包括以下步驟:
7、101、在transformer的正向傳遞中,將n個(gè)向量作為線性層的輸入,線性層輸出m個(gè)向量,則線性層的第j個(gè)輸出表示為:
8、
9、其中,aj表示一個(gè)線性層的第j個(gè)輸出;mj,k表示掩碼矩陣第j行、第列的元素值,當(dāng)mj,k=0時(shí)表示線性層第k個(gè)輸入與第j個(gè)輸出之間的映射權(quán)重被剪枝,當(dāng)mj,k=1時(shí)表示線性層第k個(gè)輸入與第j個(gè)輸出之間的映射權(quán)重被保留;wj,k為線性層第k個(gè)輸入與第j個(gè)輸出之間的映射權(quán)重;xk表示線性層的第k個(gè)輸入值;
10、102、在反向損失傳播中,掩碼矩陣二元階躍函數(shù)的梯度處處為0,線性層的顯著性得分在線性層的學(xué)習(xí)過(guò)程可以表示為:
11、
12、其中,l表示transformer網(wǎng)絡(luò)的損失函數(shù),mj表示一個(gè)線性層的第j個(gè)輸出向量對(duì)應(yīng)的掩碼向量,當(dāng)mj=1時(shí)表示線性層第j個(gè)輸出向量被保留,當(dāng)mj=0時(shí)表示線性層第j個(gè)輸出向量被剪枝;sj表示一個(gè)線性層的第j個(gè)輸出向量對(duì)應(yīng)的顯著性得分;
13、201、在transformer的正向傳遞中,掩碼多頭自注意力的輸出表示為:
14、
15、其中,msa表示輸入數(shù)據(jù)對(duì)應(yīng)的掩碼多頭自注意力的輸出;h為自注意力頭數(shù);wproj為線性變換矩陣;attnh為自注意力機(jī)制中第h個(gè)注意力頭的輸出;mh表示一個(gè)注意力層的第h注意力頭對(duì)應(yīng)的掩碼向量;
16、202、掩碼多頭自注意力的顯著性得分對(duì)應(yīng)的學(xué)習(xí)過(guò)程包括:
17、
18、其中,sh表示一個(gè)注意力層的第h個(gè)注意力頭對(duì)應(yīng)的顯著性得分;
19、301、對(duì)transformer網(wǎng)絡(luò)進(jìn)行剪枝,剪枝時(shí)使用增廣拉格朗日方法在閾值參數(shù)上構(gòu)造正則化,在閾值參數(shù)上構(gòu)造正則化的懲罰值表示為:
20、
21、其中,lp表示正則化的懲罰值;λ1、λ2為2個(gè)拉格朗日乘子;rt為目標(biāo)剪枝比;r為當(dāng)前剪枝比;
22、302、transformer網(wǎng)絡(luò)中第l層的閾值參數(shù)βl的學(xué)習(xí)過(guò)程表示為:
23、
24、其中,nl為transformer網(wǎng)絡(luò)中第l層的參數(shù)個(gè)數(shù);σ(βl)表示對(duì)閾值參數(shù)βl進(jìn)行標(biāo)準(zhǔn)化;n為transformer網(wǎng)絡(luò)中所有層參數(shù)總個(gè)數(shù)。
25、優(yōu)選地,根據(jù)全局蒸餾損失和交點(diǎn)蒸餾損失進(jìn)行蒸餾更新的過(guò)程包括以下步驟:
26、比較教師網(wǎng)絡(luò)中模型輸出的特征圖與真實(shí)目標(biāo)的特征圖,分別計(jì)算得到用于分離前景的二進(jìn)制掩碼和用于分離背景的縮放掩碼;
27、計(jì)算模型輸出特征圖的空間和通道注意力掩碼;
28、利用教師網(wǎng)絡(luò)的掩碼引導(dǎo)學(xué)生網(wǎng)絡(luò)進(jìn)行焦點(diǎn)蒸餾損失的學(xué)習(xí);
29、捕捉教師網(wǎng)絡(luò)中模型輸出的特征圖的全局信息,通過(guò)全局蒸餾學(xué)習(xí)彌補(bǔ)焦點(diǎn)蒸餾學(xué)習(xí)中缺失的目標(biāo)與整體圖像的交互信息。
30、優(yōu)選地,用于分離前景的二進(jìn)制掩碼表示為:
31、
32、其中,mi,j表示二進(jìn)制掩碼第i行、第j列的值,(i,j)表示教師網(wǎng)絡(luò)中模型輸出的特征圖中第i行、第j列的像素點(diǎn),r表示真實(shí)目標(biāo)對(duì)應(yīng)的檢測(cè)框。
33、優(yōu)選地,用于分離背景的縮放掩碼表示為:
34、
35、其中,(i,j)表示教師網(wǎng)絡(luò)中模型輸出的特征圖中第i行、第j列的像素點(diǎn),r表示真實(shí)目標(biāo)對(duì)應(yīng)的檢測(cè)框,hr表示檢測(cè)框r的高,wr表示檢測(cè)框r的寬;h表示教師網(wǎng)絡(luò)中模型輸出的特征圖的高,w表示教師網(wǎng)絡(luò)中模型輸出的特征圖的寬。
36、優(yōu)選地,計(jì)算模型輸出特征圖的空間和通道注意力掩碼包括:
37、as(f)=h·w·softmax(gs(f)/t)
38、
39、ac(f)=c·softmax(gc(f)/t)
40、
41、其中,as(f)表示教師網(wǎng)絡(luò)中模型輸出的特征圖f的空間注意力掩碼,h表示教師網(wǎng)絡(luò)中模型輸出的特征圖的高,w表示教師網(wǎng)絡(luò)中模型輸出的特征圖的寬,gs(f)表示用于分離前景的二進(jìn)制掩碼,t為蒸餾溫度超參數(shù),c表示教師網(wǎng)絡(luò)中模型輸出的特征圖的通道數(shù),fc表示特征圖第c個(gè)通道的特征值,c∈{1,2,…,c},|·|表示求絕對(duì)值;ac(f)表示教師網(wǎng)絡(luò)中模型輸出的特征圖f的通道注意力掩碼,gc(f)表示用于分離背景的縮放掩碼,fi,j表示特征圖中第i行、第j列像素點(diǎn)的像素值。
42、優(yōu)選地,利用教師網(wǎng)絡(luò)的掩碼引導(dǎo)學(xué)生網(wǎng)絡(luò)進(jìn)行焦點(diǎn)蒸餾損失學(xué)習(xí)的過(guò)程中,損失函數(shù)表示為:
43、lfocal=lfea+lat
44、
45、其中,lfocal為焦點(diǎn)蒸餾損失學(xué)習(xí)的損失函數(shù);lfea為焦點(diǎn)蒸餾損失學(xué)習(xí)中的特征損失,lat為焦點(diǎn)蒸餾損失學(xué)習(xí)中的注意損失;γ表示損耗的可學(xué)習(xí)參數(shù);表示教師網(wǎng)絡(luò)的空間注意力掩碼與學(xué)生網(wǎng)絡(luò)的空間注意力掩碼之間的l1損失;表示教師網(wǎng)絡(luò)的通道注意力掩碼與學(xué)生網(wǎng)絡(luò)的通道注意力掩碼之間的l1損失;α、β表示可學(xué)習(xí)參數(shù);h、w、c分別為教師網(wǎng)絡(luò)中模型輸出的特征圖的高、寬、通道數(shù);mi,j表示二進(jìn)制掩碼矩陣中第i行、第j列的值;si,j表示縮放掩碼矩陣中第i行、第j列的值;表示空間注意力掩碼中第i行、第j列的值;表示第k個(gè)通道注意力掩碼的值;表示教師模型中transformer編碼器輸出特征圖的第k個(gè)通道中第i行、第j列的值;f(·)為將學(xué)生網(wǎng)絡(luò)重塑為教師網(wǎng)絡(luò)的自適應(yīng)層;表示學(xué)生模型中transformer編碼器輸出特征圖的第k個(gè)通道中第i行、第j列的值。
46、優(yōu)選地,通過(guò)構(gòu)建gcblock捕捉教師網(wǎng)絡(luò)中模型輸出的特征圖的全局信息,gcblock捕獲全局信息的過(guò)程包括:
47、
48、其中,r(f)表示教師網(wǎng)絡(luò)中模型輸出的特征圖f的全局信息;wv2(·)、wv1(·)、wk(·)表示卷積操作;relu(·)為relu激活函數(shù);ln(·)表示層歸一化函數(shù);fj表示教師網(wǎng)絡(luò)中模型輸出的特征圖f中第j個(gè)像素,np表示教師網(wǎng)絡(luò)中模型輸出的特征圖f中像素的個(gè)數(shù)。
49、優(yōu)選地,通過(guò)全局蒸餾學(xué)習(xí)彌補(bǔ)焦點(diǎn)蒸餾學(xué)習(xí)中缺失的目標(biāo)與整體圖像的交互信息時(shí),采用的損失函數(shù)表示為:
50、lglobal=λ·∑(r(ft)-r(fs))2
51、其中,lglobal表示全局蒸餾學(xué)習(xí)的損失函數(shù);λ表示平衡損失的超參數(shù);r(ft)表示教師網(wǎng)絡(luò)transformer編碼器輸出的特征圖輸入進(jìn)gcblock捕獲的特征圖全局信息;r(fs)表示學(xué)生網(wǎng)絡(luò)transformer編碼器輸出的特征圖輸入進(jìn)gcblock捕獲的特征圖全局信息。
52、本發(fā)明利用一組可學(xué)習(xí)的剪枝相關(guān)參數(shù)自適應(yīng)調(diào)整transformer寬度,解決了帶有transformer結(jié)構(gòu)的模型無(wú)法有效剪枝的問(wèn)題;同時(shí)本發(fā)明分離前景和背景使學(xué)生網(wǎng)絡(luò)專(zhuān)注于教師網(wǎng)絡(luò)的關(guān)鍵像素和通道并重建不同像素之間的關(guān)系,解決了傳統(tǒng)知識(shí)蒸餾方法無(wú)法平衡前景和背景特征差異的問(wèn)題。最終本發(fā)明可以使模型在無(wú)明顯精度下滑的前提下,顯著降低模型的計(jì)算量并大幅提升模型的推理速度。