Im2Col を用いた最大プーリングを MATLAB で解説
新井仁之(早稲田大学)
公開日 2024年11月18日
Ver. 1.2 2024年12月20日
Ver. 1.3 2025年1月1日
ここでは解説
「深層学習で使われる im2col を MATLABで解説」(http://www.araiweb.matrix.jp/Program/Im2Col_tutorial2.html)で述べた Im2Col を使った最大プーリングの方法について解説する.なお本ノートの MATLAB コードは,python コード([1], [2], [3],特に [2] )に基づいたものである.
まず最大プーリングとは,たとえば 2 x 2 の小ブロックの中の最大値のみを抜き出して,並べる操作のことである.元の画像が 2M x 2M のとき,あるいは 2M+1 x 2M+1 のときは,M x M の画像となる.たとえば次のようにである.
もちろん別のブロックの分け方,スライドの仕方でもよい.
Pooling はたたみ込み層で行うが,たたみ込み層では入力画像数が複数ある.単純な具体例で操作がどのようなものかを見るため,次の例で考えていく.
この例では入力データは3個あるものとする.
この例では,入力データは
[ Number of Filters, Number of Channels, Image_height, Image_width]
の4階テンソルであるとする.深層学習では,Channels は前層から受け渡されるデータの数を表し,Number of Filters は考えている層内のニューロンの個数を表す.
特に [3,2,6,6] の場合で見ていく.
まずこの層への仮想的な入力データを作る.
Image = zeros(NumFilters,NumChannels,6,6);
Image_org = Image; %保存のため
disp(['Image(',num2str(i),',',num2str(j),',:,:) = '])
end
1 2 3 4 5 6
7 8 9 10 11 12
13 14 15 16 17 18
19 20 21 22 23 24
25 26 27 28 29 30
31 32 33 34 35 36
1 2 3 4 5 6
7 8 9 10 11 12
13 14 15 16 17 18
19 20 21 22 23 24
25 26 27 28 29 30
31 32 33 34 35 36
35 1 6 26 19 24
3 32 7 21 23 25
31 9 2 22 27 20
8 28 33 17 10 15
30 5 34 12 14 16
4 36 29 13 18 11
35 1 6 26 19 24
3 32 7 21 23 25
31 9 2 22 27 20
8 28 33 17 10 15
30 5 34 12 14 16
4 36 29 13 18 11
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
1 1 1 1 1 1
%このノートでは Padding = 0, Stride = 2 の場合に対応
Image = padarray(Image,[0 0 Padding Padding],0,'both');
%% 今の場合は Padding = 0 なので変化なし.
% disp(['Image(',num2str(i),',',num2str(j),',:,:) =']);
% squeeze(Image(i,j,:,:))
[NumImage, Channel, Image_height, Image_width] = size(Image)
NumImage = 3
Channel = 2
Image_height = 6
Image_width = 6
2 x 2 のブロック内で,最大をとる最大プーリングを考える.ブロックを設定する.
最大プーリングで出力される画像のサイズは,たとえば入力データが 2M x 2N ならば M x N になる.これは (2M-2)/2+1=M,(2N-2)/2+1=N である.奇数 2M+1 の場合も同様の値になる.実際,
(2M+1-2)/2+1= M-1+1=M
Output_height = fix((Image_height-block_height+2*Padding)/Stride)+1;
Output_width = fix((Image_width-block_width+2*Padding)/Stride)+1;
disp(['[Output_height, output_width] = ','[', num2str(Output_height),',',num2str(Output_width),']']);
[Output_height, output_width] = [3,3]
Image の Im2Col は次のようなものになっている.
Col = Im2Col(Image,[block_height,block_width],Stride,Padding);
% [NumImage*Output_height*Output_width Channel*block_height*block_width]
% = [3*3*3 x 2*2*2] = [27 8]
disp(Col);
1 2 7 8 1 2 7 8
3 4 9 10 3 4 9 10
5 6 11 12 5 6 11 12
13 14 19 20 13 14 19 20
15 16 21 22 15 16 21 22
17 18 23 24 17 18 23 24
25 26 31 32 25 26 31 32
27 28 33 34 27 28 33 34
29 30 35 36 29 30 35 36
35 1 3 32 35 1 3 32
6 26 7 21 6 26 7 21
19 24 23 25 19 24 23 25
31 9 8 28 31 9 8 28
2 22 33 17 2 22 33 17
27 20 10 15 27 20 10 15
30 5 4 36 30 5 4 36
34 12 29 13 34 12 29 13
14 16 18 11 14 16 18 11
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1
Col = permute(Col,[2 1]);
% [Channel*block_height*block_width NumImage*Output_height*Output_width]
disp(Col)
1 3 5 13 15 17 25 27 29 35 6 19 31 2 27 30 34 14 1 1 1 1 1 1 1 1 1
2 4 6 14 16 18 26 28 30 1 26 24 9 22 20 5 12 16 1 1 1 1 1 1 1 1 1
7 9 11 19 21 23 31 33 35 3 7 23 8 33 10 4 29 18 1 1 1 1 1 1 1 1 1
8 10 12 20 22 24 32 34 36 32 21 25 28 17 15 36 13 11 1 1 1 1 1 1 1 1 1
1 3 5 13 15 17 25 27 29 35 6 19 31 2 27 30 34 14 1 1 1 1 1 1 1 1 1
2 4 6 14 16 18 26 28 30 1 26 24 9 22 20 5 12 16 1 1 1 1 1 1 1 1 1
7 9 11 19 21 23 31 33 35 3 7 23 8 33 10 4 29 18 1 1 1 1 1 1 1 1 1
8 10 12 20 22 24 32 34 36 32 21 25 28 17 15 36 13 11 1 1 1 1 1 1 1 1 1
これを reshape して
(「block_height」*「block_width」)x (「Channel*NumImage」*「Output_height」*「Output_width」)
にする:
Col = reshape(Col,[block_height*block_width Channel*NumImage*Output_height*Output_width]);
disp(Col)
1 1 3 3 5 5 13 13 15 15 17 17 25 25 27 27 29 29 35 35 6 6 19 19 31 31 2 2 27 27 30 30 34 34 14 14 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
2 2 4 4 6 6 14 14 16 16 18 18 26 26 28 28 30 30 1 1 26 26 24 24 9 9 22 22 20 20 5 5 12 12 16 16 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
7 7 9 9 11 11 19 19 21 21 23 23 31 31 33 33 35 35 3 3 7 7 23 23 8 8 33 33 10 10 4 4 29 29 18 18 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
8 8 10 10 12 12 20 20 22 22 24 24 32 32 34 34 36 36 32 32 21 21 25 25 28 28 17 17 15 15 36 36 13 13 11 11 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
Image の列に関する最大値を Max_values,最大値の位置を Max_positions とする.
[Max_values,Max_positions] = max(Col);
disp(Max_values);
8 8 10 10 12 12 20 20 22 22 24 24 32 32 34 34 36 36 35 35 26 26 25 25 31 31 33 33 27 27 36 36 34 34 18 18 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
disp(Max_positions);
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 1 1 2 2 4 4 1 1 3 3 1 1 4 4 1 1 3 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
Pooling_Image = reshape(Max_values,[Channel Output_width Output_height NumImage]);
% この reshape については Appendix 2 を参照.
disp(['Pooling_Image(',num2str(i),',:,:,',num2str(j),') =']);
disp(squeeze(Pooling_Image(i,:,:,j)));
end
8 20 32
10 22 34
12 24 36
35 31 36
26 33 34
25 27 18
8 20 32
10 22 34
12 24 36
35 31 36
26 33 34
25 27 18
Pooling_Image = permute(Pooling_Image,[4 1 3 2]);
%[NumImages Channel Output_height Output_width
% この reshape については Appendix 2 を参照.
この reshape の結果は:
disp(['Pooling_Image(',num2str(i),',',num2str(j),',:,:) =']);
squeeze(Pooling_Image(i,j,:,:))
end
8 10 12
20 22 24
32 34 36
8 10 12
20 22 24
32 34 36
35 26 25
31 33 27
36 34 18
35 26 25
31 33 27
36 34 18
以上の操作をまとめて関数ファイルとする.
最大プーリングの関数 Pooling Max
function [Pooling_Image,Max_positions] = Pooling_Max(Image,block,Stride,Padding)
% 入力タイプ Images(Number of Images, Number of Channels, Image_height, Image_width)
% 出力タイプ Pooling_image(Number of Images Number of Channels Output_height Output_width)
Image = padarray(Image,[0 0 Padding Padding],0,'both');
[NumImage, Channels, Image_height, Image_width] = size(Image);
Output_height = fix((Image_height - block_height + 2*Padding)/Stride)+1;
Output_width = fix((Image_width - block_width + 2*Padding)/Stride)+1;
Col = Im2Col(Image,[block_height,block_width],Stride,Padding);
Col = permute(Col,[2 1]);
Col = reshape(Col,[block_height*block_width Channels*NumImage*Output_height*Output_width]);
[Max_values,Max_positions] = max(Col);
Pooling_Image = reshape(Max_values,[Channels Output_height Output_width NumImage]);
Pooling_Image = permute(Pooling_Image,[4 1 3 2]);
% Pooling image のサイズ:[NumImage Channel Output_height Output_width]
Max_Pooling を使って計算する
[Pooling_Image,Max_Positions] = Pooling_Max(Image_org,[2 2],2,0);
disp(['Pooling_Image(',num2str(i),',',num2str(j),',:,:) =']);
squeeze(Pooling_Image(i,j,:,:))
end
8 10 12
20 22 24
32 34 36
8 10 12
20 22 24
32 34 36
35 26 25
31 33 27
36 34 18
35 26 25
31 33 27
36 34 18
Appeddix 1
function Col = Im2Col(Image,block,Stride,Padding)
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Image = Image(Number of Filter, Channel, Image_height, Image_width)
% block = [number1, number2]
% Stride = number, Padding = number
% (NumImage*Output_height * Output_width) x (Channel * Block_height*Block_width)
% Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1
% Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1
% 次の [A], [B], [C], [D] の Pythonプログラムにもとづく.
% [A] https://docs.chainer.org/en/v7.8.1.post1/reference/generated/chainer.functions.im2col.html
% [B] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
% [C] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
% [D] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[NumImage, Channel, Image_height, Image_width] = size(Image);
Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1;
Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1;
col = zeros(NumImage,Block_height*Block_width,Channel,Output_hight,Output_width);
Image = padarray(Image,[0 0 Padding Padding],0,'both');
% MATLAB の im2col の配列にするには次のようにする.
%col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:,w:Stride:w-1+WS ,h:Stride:h-1+HS);
% この部分は,Python などの他の文献のようにするには次のようにする.
col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:, h:Stride:h-1+HS, w:Stride:w-1+WS);
% MATLAB式 im2col にするには次のようにする:
% Col = permute(col,[2 3 4 5 1]);
% この部分は,Python などの他の文献のようにするには次のようにする.
Col = permute(col,[2 3 5 4 1]);
Col = reshape(Col, [Channel*Block_height*Block_width NumImage*Output_hight*Output_width ]);
Col = permute(Col,[2 1]);
Appendix 2 MATLAB の reshape について
disp(A)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
A = reshape(A,[2 3 3 3]); %[Channel Output_width Output_height NumImage]
disp(['A(',num2str(i),',:,:,',num2str(j),') =']);
disp(squeeze(A(i,:,:,j)));
end
19 25 31
21 27 33
23 29 35
37 43 49
39 45 51
41 47 53
20 26 32
22 28 34
24 30 36
38 44 50
40 46 52
42 48 54
A = permute(A,[4 1 3 2]);%[Channel Output_height Output_width NumImage]
disp(['A(',num2str(i),',',num2str(j),',:,:) =']);
disp(squeeze(A(i,j,:,:)));
end
19 21 23
25 27 29
31 33 35
20 22 24
26 28 30
32 34 36
37 39 41
43 45 47
49 51 53
38 40 42
44 46 48
50 52 54
参考文献
[1] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
[2] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
[3] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.