Neural Ordinary Differential Equationsの実装
この記事は授業「映像メディア学」の課題として作成されました.
NeurIPS 2018のBest PaperであるNeural Ordinary Differential Equationsというモデルがあります.非常にセンセーショナルな論文なので各所で紹介されていますが,実際の実装方法までご存じの方は少ないかと思います.そこで本記事ではPytorchを使って必要な関数をフルスクラッチで実装することでその内容を深く理解することに挑戦します.
Neural Ordinary Differential Equations (NeuralODE) とは
Rick T. Q. Chenらによって提案された,「層が連続的」なニューラルネットワークです.概念のインパクトとその有用性によりNeurIPS 2018のBest Paperに選ばれました.
NeuralODEの元になっているのはResNetです.ResNetは以下のようにある層の出力\({\bf z}_t\)と次の層の出力\({\bf z}_{t+1}\)の差分\(f\)を学習します.
\[ {\bf z}_{t+1} = {\bf z}_t + f({\bf z}_t,\theta_t).\]
この式が常微分方程式の1次近似であるオイラー法と類似していることから,著者らはメインアイディアである次の関係を導入しました.
\[ \frac{d {\bf z}(t)}{dt} = f({\bf z}(t),t,\theta) \]
つまり,ニューラルネットワークの隠れ状態を時間方向に連続的なものとしてしまい,その時間発展を記述する常微分方程式を学習するのです.これがNeuralODEの肝となります.
NeuralODEは既存のニューラルネットワークにレイヤーとして組み込む形で使います.推論時には常微分方程式のソルバ(ODEソルバ)によって上の常微分方程式を一定の時刻\(t\in[t_0,t_1] \)で解くことにより出力を求めます.そして後に述べるように,逆伝搬についても実は常微分方程式を解くだけで計算できます.
NeuralODEのメリット
NeuralODEの仕組みは非常にシンプルなものとなっています.しかしながら以下に示すような良い性質を有しているのです.
1つ目は推論も逆伝播も常微分方程式で記述される点です.これによりいずれの過程でも中間ノードの値をメモリに保存しておく必要がないため,メモリ効率が優れるという利点があります.また無数の研究がなされている既存のODEソルバを活用できる点も魅力的です.効率的かつ正確に解くことができる上,精度が必要な場合は時間刻みを小さくすることで適応的に精度を高めることができます.
2つ目は少ないパラメータで高い精度が出るという点です.論文中の実験ではResNetよりも少ないパラメータで同等の精度が出ることが示されています.理由について論文では「内部状態を連続に時間発展させるならばパラメータも急激には変わらない==時間的に近くのパラメータは自動的に関連付けられる」ためある種の正則化が働いていると説明されています(理論的な根拠は示されていないので半信半疑ですが...).
誤差逆伝播
NeuralODEの推論は前述の通りODEソルバを使います.この手法の巧妙な点は学習を進めるためのパラメータの更新方向もODEソルバで効率的に計算できる点にあります.
学習のためのコスト関数\(L\)を次のように表現します.
\[ L({\bf z}(t_1)=L\left({\bf z}(t_0)+\int_{t_0}^{t_1}f({\bf z}(t),t,\theta)dt\right)=L({\rm ODESolve}({\bf z}(t_0),f,t_0,t_1,\theta)). \]
学習を進めるにはこのコスト関数の入力による微分\(\frac{\partial L}{\partial {\bf z}}\)とパラメータ\(\theta\)による微分\(\frac{\partial L}{\partial \theta}\)を求める必要があります.そして実はこれらの量はadjoint sensitivity methodという古くポントリャーギンによって考案された手法により効率的に計算できるのです.
adjoint sensitivity methodではまずadjointと呼ばれる量\( {\bf a}(t)=\frac{\partial L}{\partial {\bf z}}(t) \)を求めます.この量のダイナミクスは微分の連鎖律のアナロジーとなる次の微分方程式で記述されます*1.
\[ \frac{d {\bf a}(t)}{dt}=-{\bf a}(t)^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial {\bf z}} \]
この式を初期条件\(\frac{ \partial L}{\partial {\bf z}}(t_1) \)から時間を遡る方向へ解くことで\( \frac{\partial L}{\partial {\bf z}}(t_0) \)を得ることができます.また\(\frac{\partial L}{\partial \theta}\)についても
\[ \frac{\partial L}{\partial \theta} =\int_{t_1}^{t_0}{\bf a}^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial \theta}dt \]
という積分を解くことで求められます.
以上をまとめると,学習に必要な勾配を求めるには初期状態をそれぞれ\( {\bf z}(t_1), \frac{\partial L}{\partial {\bf z}}(t_1), 0\)として連立常微分方程式
\[\begin{align} \frac{d {\bf z}}{d t}&= f({\bf z}(t),t,\theta)\\ \frac{d {\bf a}(t)}{dt}&=-{\bf a}(t)^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial {\bf z}}\\ \frac{d}{dt}\frac{\partial L}{\partial \theta}&={\bf a}^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial \theta} \end{align} \]
を時刻\(t_1 \)から\(t_0 \)で解けばよいということになります.このとき任意のODEソルバを使うことができる上,式中の\({\bf a}^T\)がかかる項は機械学習ライブラリの自動微分によって効率的に計算することができます.
実装
ここからは実際に実装したコードを通じて解説します.実装に用いた言語はPython 3.8.2でフレームワークはPytorch 1.5.0,CUDAのバージョンは10.1です*2.
ソースコードの全体はhttps://github.com/lizaf999/blog/tree/master/neuralODEにあります.
モデル定義
以下のソースコードにモデルのインターフェース部分の実装を示します.モデルの大まかな枠組みとしては,ResNetのようにdown samplingを2回行い,その後本体であるODEBlockをレイヤーとして挿入しています.ODEBlockモジュールではこれまで解説してきたforward / backwardを実装した関数AdjointFunc()に,常微分方程式の右辺\( f({\bf z}(t),t,\theta) \)をモデリングする関数ODEFunc()を投げています.
ODEFuncについて眺めてみると,一般的な畳み込みニューラルネットワークとほとんど変わらない構造をしていることがわかります.実際にそのとおりで,関数\( f({\bf z}(t),t,\theta)\)を表現するために入力されてきたデータに時刻\(t\)を持つチャンネルを追加しているだけです.\(t\)に依存する関数の実装方法は色々あると思いますが,もっともシンプルなこの形で十分なのです.
続いてODEBlockの中の関数AdjointFuncについて見てみます.この関数は引数として前のレイヤーからの出力と上述のODEFuncのインスタンス,積分する時刻\(t_0, t_1\)及びODEFuncのパラメータをとります.最後のODEFuncのパラメータが肝で,前項で見てきたadjoint sensitivity methodを適用するために学習する重みを全て渡す必要があります.このときPytorchのモジュールに備わっているparameters()では扱いづらいリスト形式となるため,flat_paramerts()で一つのtensorにしています.
1箇所謎のバグに遭遇したのでそのことを記述しておきます.down samplingの最後となる44行目に1x1の何もしない畳み込み層を挟んでいますが,これはpytorchのエラーを回避するためです.この畳み込み層を挟まないと自動微分が働くなります.公式repositoryのissueでも解決されていないようなので,なんなのでしょう....
from adjoint import AdjointFunc,flat_parameters import torch import torch.nn as nn def conv3x3(in_planes, out_planes, padding=1,stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class ODEFunc(nn.Module): def __init__(self,inplanes): super(ODEFunc, self).__init__() self.conv1 = conv3x3(inplanes+1,inplanes) self.norm1 = nn.BatchNorm2d(inplanes) self.conv2 = conv3x3(inplanes+1,inplanes) self.norm2 = nn.BatchNorm2d(inplanes) def forward(self,x,t): t_tensor1 = torch.ones((x.size()[0],1,x.size()[2],x.size()[3]),device="cuda")*t t_tensor1.requires_grad = False x = torch.cat((x,t_tensor1),dim=1) x = self.conv1(x) x = self.norm1(torch.relu(x)) t_tensor2 = torch.ones_like(t_tensor1)*t t_tensor2.requires_grad = False x = torch.cat((x,t_tensor2),dim=1) x = self.conv2(x) x = self.norm2(torch.relu(x)) return x class ODEBlock(nn.Module): def __init__(self,inplanes): super(ODEBlock,self).__init__() self.func = ODEFunc(inplanes) def forward(self,x): x = AdjointFunc.apply(x,self.func,torch.tensor([0.0],device="cuda"),torch.tensor([1.0],device="cuda"),flat_parameters(self.func.parameters())) return x class Model(nn.Module): def __init__(self,num_classes=10): super(Model,self).__init__() dim = 64 self.downsampling = nn.Sequential( nn.Conv2d(1,dim,5,2,0),nn.BatchNorm2d(dim),nn.ReLU(inplace=True), nn.Conv2d(dim,dim,4,2),nn.BatchNorm2d(dim),nn.ReLU(inplace=True), conv1x1(dim,dim)#to avoid error ) self.neuralODE = ODEBlock(dim) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(64, num_classes) def forward(self,x): x = self.downsampling(x) x = self.neuralODE(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x
常微分方程式の組み込み
import numpy as np import torch import torch.nn as nn import torch.nn.functional as functional class AdjointFunc(torch.autograd.Function): def __init__(self): super(AdjointFunc,self).__init__() @staticmethod def forward(ctx,z0,func,t0,t1,theta_flatten): z1 = odeint(z0,func,t0,t1) ctx.func = func ctx.t0 = t0 ctx.t1 = t1 ctx.save_for_backward(z1.clone()) return z1 @staticmethod def backward(ctx,dLdz1): func = ctx.func t0 = ctx.t0 t1 = ctx.t1 z1= ctx.saved_tensors[0] theta_list = list(func.parameters()) s0 = [z1.clone(),dLdz1.clone()] for theta_i in theta_list: s0.append(torch.zeros_like(theta_i)) def aug_dynamics(x,t): z,a,dfdth_unused = x[0],x[1],x[2] z = torch.autograd.Variable(z.data,requires_grad=True) torch.set_grad_enabled(True)#important f = func(z,t) f.backward(a) adfdz = z.grad #z.grad.zero_()#不要 adfdth_list = [] for theta_i in theta_list: adfdth_list.append(-theta_i.grad) theta_i.grad.zero_() return (f, -adfdz, *adfdth_list) rlt = odeint(s0,aug_dynamics,t1,t0) #z0 = rlt[0] dLdz0 = rlt[1] dLdth0 = [] for i in range(2,len(rlt)): dLdth0.append(rlt[i]) dLdth0 = flat_parameters(dLdth0) return dLdz0,None,None,None,dLdth0
実験
モデル | テスト誤差 | パラメータ数 |
ResNet | 0.41% | 0.60M |
NeuralODE | 0.42% | 0.22M |
モデル | テスト誤差 | パラメータ数 |
ResNet | 0.88% | 0.15M |
NeuralODE | 0.82% | 0.14M |
発展
所感とまとめ
*1:導出は自明ではありません.詳細は論文の付録に記載されています.
*2:PytorchとCUDAのバージョンによっては一部の自動微分の実装でエラーが出て実行できませんでした.
*3:scipyのodeintではありません.GPUを使うためにtorch.tensorを使って自前で実装する必要があります
*4:本当はコンテキストctx経由でbackwardに\(\theta\)を伝えたいところですがうまく動きませんでした.
*5:Veit et al.. Residual networks behave like ensembles of relatively shallow networks, NeurIPS 2016.