深層学習で使われる im2col を MATLAB で解説
新井仁之(早稲田大学)
公開日 2024年11月18日
改訂 2024年12月30日
CNN の畳み込み計算や pooling 計算のために,Chainer [1] などに im2col があることが良く知られている.これを MATLAB コードにしつつ im2col の解説をしたい.なお MATLAB にも im2col があるが,配列に仕方が縦を優先させる配列方法になっている.また stride = 1 の場合が扱われている.stride はたたみ込みに使う際,フィルタをずらしていく幅に相当する.このノートは配列の仕方は MATLAB 仕様(下図参照)で,stride の場合も述べていく.ここでは im2col を MATLAB に内装されている im2col と区別するため便宜上 「Im2Col」 とする.なお本ノートの MATLAB コードは,python コード([1], [2], [3], [4] )に基づいたものである. 配列を追う際に,MATLAB は縦方向優先, Python は横方向優先となるという違いがあるので注意してほしい.
序 MATLAB の im2col
MATLAB に im2col がある.このチュートリアルの一つの目的は,MATLAB の im2col(A, block_size, 'sliding') と同じ働きをして,さらに stride の場合を付加することでもある.まず MATLAB の im2col(A, block_size, 'sliding') がどのようなものかを簡単な例でみておく. A = reshape(Image1,[N N]);
A = permute(A,[2 1])
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
A_col = im2col(A, [3 3], 'sliding');
disp(A_col)
1 5 2 6
5 9 6 10
9 13 10 14
2 6 3 7
6 10 7 11
10 14 11 15
3 7 4 8
7 11 8 12
11 15 12 16
この配列の変換は次のような規則で行われている.
1. MATLAB の im2col を stride ≥ 2 に拡張する
たたみ込みニューラルネットでプーリングをする際に使われる,良く知られた方法は,im2col の赤い正方形の窓を 1 づつずらすのではなく,2 個飛びでずらすというものである(stride = 2).すなわち
stride ≥ 2 の im2col については,Python コードが良く知られている(たとえば [1], [2], [3], [4] 参照).本稿では,以下この Python コード(特に [3] を参考)を MATLAB に書き直しつつ,Im2Col の計算方法を解説していきたい.Phthonコードを知りたい方は,これらの文献を見てほしい.
2. Im2Col について
2.1 入力画像例
ここでの Im2Col への入力画像は
Image = 「画像数」 x 「チャネル数」 x 「画像高さ」 x 「画像幅」
の4階テンソルである.
まず簡単な二つの画像の例で Im2Col による変換の結果のみ見ていく.具体的なコードはそのあとに解説する.
一つ目の画像の作成をする.これは基本的には,画像1枚から成るデータである.
sq = squeeze(Image1(i,j,:,:));
disp(['Image1(',num2str(i),',',num2str(j),',:,:) の表示'])
end
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
disp(['size(Image1) = ',num2str(size(Image1))])
二つ目の画像の例は,画像数 2,チャネル 2 の合計4枚の画像から成るデータである.
Image2 = zeros(NumImage,Channel,N,N);
Image2(1,1,:,:) = A(:,:);
Image2(2,1,:,:) = B(:,:);
Image2(1,2,:,:) = C(:,:);
Image2(2,2,:,:) = D(:,:);
disp(['size(Image) = ',num2str(size(Image2))])
sq = squeeze(Image2(i,j,:,:));
disp(['Image2(',num2str(i),',',num2str(j),',:,:) の表示'])
end
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
1 0 0 -1
0 1 0 0
0 0 1 0
0 0 0 1
16 2 3 13
5 11 10 8
9 7 6 12
4 14 15 1
13 8 12 1
3 10 6 15
2 11 7 14
16 5 9 4
% Image は後で更新するのでバックアップのため Image_org で保存
2.2 Im2Col がどのように配列し直すかを結果を先に見ておく.
Stride = 1 の場合
Im2Col は入力値は
Col = Im2Col(Image, Block_size, Stride, Padding)
である.
Block_size, Stride, Padding は,どのようなサイズのフィルタとどのようなタイプのたたみ込みをするかで設定がかわる.なお Padding は画像の外側に 0 をパディングするときの幅である.ここで見る例では次のように設定する.
入力画像は先に定めた Image1 とする.
COL1 = Im2Col(Image1,Block_size,Stride,Padding);
出力のサイズと実際の出力は次のようになっている.まずは結果のみ示し,説明は後述する.
disp(COL1)
1 5 2 6
5 9 6 10
9 13 10 14
2 6 3 7
6 10 7 11
10 14 11 15
3 7 4 8
7 11 8 12
11 15 12 16
これは,先述の MATLAB の im2col と同じ結果であることが確認できる.
Image2 の場合は次のようになる.
COL = Im2Col(Image2,Block_size,Stride,Padding);
disp(COL)
1 5 2 6 1 0 0 1
5 9 6 10 0 0 1 0
9 13 10 14 0 0 0 0
2 6 3 7 0 1 0 0
6 10 7 11 1 0 0 1
10 14 11 15 0 0 1 0
3 7 4 8 0 0 -1 0
7 11 8 12 0 1 0 0
11 15 12 16 1 0 0 1
16 5 2 11 13 3 8 10
5 9 11 7 3 2 10 11
9 4 7 14 2 16 11 5
2 11 3 10 8 10 12 6
11 7 10 6 10 11 6 7
7 14 6 15 11 5 7 9
3 10 13 8 12 6 1 15
10 6 8 12 6 7 15 14
6 15 12 1 7 9 14 4
Stride = 2 の場合
プーリングに利用する場合は,Stride = 2 で適用するので,この場合も見ておく.
Block_size = [Block_height, Block_width];
Image1 の場合.
COL = Im2Col(Image1,Block_size,Stride,Padding);
disp(COL)
1 5 2 6
9 13 10 14
3 7 4 8
11 15 12 16
第1節で図示して示した配列の変換と同じであることが確認できる.
Image2 の場合.
COL = Im2Col(Image2,Block_size,Stride,Padding);
disp(COL)
1 5 2 6 1 0 0 1
9 13 10 14 0 0 0 0
3 7 4 8 0 0 -1 0
11 15 12 16 1 0 0 1
16 5 2 11 13 3 8 10
9 4 7 14 2 16 11 5
3 10 13 8 12 6 1 15
6 15 12 1 7 9 14 4
3. Im2Colを MATLABコードにする.
それでは,Im2Col の中身を説明していく.
このノートでは,Im2Col を関数ファイルとして作る.形としては
function Col = Im2Col(Image, Block_size, Stride, Padding)
のような関数として定義する.
プログラムは一般的な形で記載するが,計算を見るときは次の Stride = 1,Padding = 0 の具体的な場合で考える.
Block_size = [Block_height, Block_height];
Im2Col関数では,はじめに,入力データの Image から次を情報を呼び出しておく.
ここでは Image2 を画像として使う.
[NumImage, Channel, Image_height, Image_width] = size(Image2)
NumImage = 2
Channel = 2
Image_height = 4
Image_width = 4
Im2Col で出力される画像サイズの計算をしておく.FIlter でたたみ込んでいく.Block_size はたたみ込みの Filter サイズと同じものを取る.元画像の周囲に 0 を挿入するゼロパディングを考慮すると,たたみ込み演算で出力されるサイズは次のようになる.ただし,stride はたたみ込みでフィルタをずらしていく幅とみなすことができる.
Output_height = fix((Image_height - Block_height + 2*Padding)/Stride) +1;
Output_width = fix((Image_width - Block_width + 2*Padding)/Stride) + 1;
例えば,サイズ [4,4] の画像に対して,Block_size = [3,3] で,Stride = 1 と Stride = 2 の場合は,次のように数える.
考えている画像の場合,画像サイズは [4,4],Block_size = [2, 2],Stride = 1,Padding = 0 なので計算結果は次のようになる.
[Output_height,Output_width]
Im2Col の配列に変換するために,次のような途中経過を格納する col を準備する.
col = zeros(NumImage, Block_height*Block_width, Channel, Output_height, Output_width);
一般に配列替えは次のように行われる.
(Block_hight * Block_width) x ( Channel) x (Output_height) x (Output_width) x NumImage
Image = padarray(Image2,[0 0 Padding Padding],0,'both');
HS = Output_height*Stride;
WS = Output_width*Stride;
col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:,w:Stride:w-1+WS ,h:Stride:h-1+HS);
% col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:, h:Stride:h-1+HS, w:Stride:w-1+WS);
col の構造は
(NumImage) x (Block_hight * Block_width) x ( Channel) x (Output_height) x (Output_width)
となっているが,これを Image2,Stride = 1,Padding = 0 の場合にどのようになっているのかを見ておく.
for j=1:Block_height*Block_width
sq = squeeze(col(i,j,k,:,:));
disp(['col(',num2str(i),',' num2str(j),',',num2str(k),',:,:) の表示'])
これを次のように配列しなおす.
このため
(NumImage) x (Block_height * Block_width) x ( Channel) x (Output_height) x (Output_width)
を
(Block_height * Block_width) x ( Channel) x (Output_height) x (Output_width) x (NumImage)
に並べ替える
% MATLAB の im2col の配列にするには次のようにする.
Col = permute(col,[2 3 4 5 1]);
さらにこれを次のように reshape する.
Col = reshape(Col, [Channel*Block_height*Block_width NumImage*Output_height*Output_width]);
disp(Col)
1 5 9 2 6 10 3 7 11 16 5 9 2 11 7 3 10 6
5 9 13 6 10 14 7 11 15 5 9 4 11 7 14 10 6 15
2 6 10 3 7 11 4 8 12 2 11 7 3 10 6 13 8 12
6 10 14 7 11 15 8 12 16 11 7 14 10 6 15 8 12 1
1 0 0 0 1 0 0 0 1 13 3 2 8 10 11 12 6 7
0 0 0 1 0 0 0 1 0 3 2 16 10 11 5 6 7 9
0 1 0 0 0 1 -1 0 0 8 10 11 12 6 7 1 15 14
1 0 0 0 1 0 0 0 1 10 11 5 6 7 9 15 14 4
これを転置すれば Im2Col の完成.
Col = permute(Col,[2 1]);
disp(Col)
1 5 2 6 1 0 0 1
5 9 6 10 0 0 1 0
9 13 10 14 0 0 0 0
2 6 3 7 0 1 0 0
6 10 7 11 1 0 0 1
10 14 11 15 0 0 1 0
3 7 4 8 0 0 -1 0
7 11 8 12 0 1 0 0
11 15 12 16 1 0 0 1
16 5 2 11 13 3 8 10
5 9 11 7 3 2 10 11
9 4 7 14 2 16 11 5
2 11 3 10 8 10 12 6
11 7 10 6 10 11 6 7
7 14 6 15 11 5 7 9
3 10 13 8 12 6 1 15
10 6 8 12 6 7 15 14
6 15 12 1 7 9 14 4
Im2Col の関数ファイル - MATLAB仕様 -
function Col = Im2Col(Image,block,Stride,Padding)
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Image = Image(Number of Filter, Channel, Image_height, Image_width)
% block = [number1, number2]
% Stride = number, Padding = number
% (NumImage*Output_height * Output_width) x (Channel * Block_height*Block_width)
% Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1
% Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1
% 参考文献 [1], [2],[3], [4] の Pythonプログラムにもとづく.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[NumImage, Channel, Image_height, Image_width] = size(Image);
Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1;
Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1;
col = zeros(NumImage,Block_height*Block_width,Channel,Output_hight,Output_width);
Image = padarray(Image,[0 0 Padding Padding],0,'both');
% MATLAB の im2col の配列にするには次のようにする.
col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:,w:Stride:w-1+WS ,h:Stride:h-1+HS);
% この部分は,Python などの他の文献のようにするには次のようにする.
% col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:, h:Stride:h-1+HS, w:Stride:w-1+WS);
% MATLAB式 im2col にするには次のようにする:
Col = permute(col,[2 3 4 5 1]);
% この部分は,Python などの他の文献のようにするには次のようにする.
% Col = permute(col,[2 3 5 4 1]);
Col = reshape(Col, [Channel*Block_height*Block_width NumImage*Output_hight*Output_width ]);
Col = permute(Col,[2 1]);
Im2Col の検証
Col2 = Im2Col(Image_org,[2 2],1,0);
disp(Col2)
1 5 2 6 1 0 0 1
5 9 6 10 0 0 1 0
9 13 10 14 0 0 0 0
2 6 3 7 0 1 0 0
6 10 7 11 1 0 0 1
10 14 11 15 0 0 1 0
3 7 4 8 0 0 -1 0
7 11 8 12 0 1 0 0
11 15 12 16 1 0 0 1
16 5 2 11 13 3 8 10
5 9 11 7 3 2 10 11
9 4 7 14 2 16 11 5
2 11 3 10 8 10 12 6
11 7 10 6 10 11 6 7
7 14 6 15 11 5 7 9
3 10 13 8 12 6 1 15
10 6 8 12 6 7 15 14
6 15 12 1 7 9 14 4
以上は stride = 1 の場合であるが,stride = 2 の場合も同様の計算で,次の結果が得られることは容易に確認できるであろう.
Col2 = Im2Col(Image_org,[2 2],2,0);
disp(Col2)
1 5 2 6 1 0 0 1
9 13 10 14 0 0 0 0
3 7 4 8 0 0 -1 0
11 15 12 16 1 0 0 1
16 5 2 11 13 3 8 10
9 4 7 14 2 16 11 5
3 10 13 8 12 6 1 15
6 15 12 1 7 9 14 4
参考文献
[1] https://docs.chainer.org/en/v7.8.1.post1/reference/generated/chainer.functions.im2col.html
[2] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
[3] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
[4] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.
Copyright © Hitoshi Arai, 2024