Neural Ordinary Differential Equationsの実装

この記事は授業「映像メディア学」の課題として作成されました.

 

NeurIPS 2018のBest PaperであるNeural Ordinary Differential Equationsというモデルがあります.非常にセンセーショナルな論文なので各所で紹介されていますが,実際の実装方法までご存じの方は少ないかと思います.そこで本記事ではPytorchを使って必要な関数をフルスクラッチで実装することでその内容を深く理解することに挑戦します.

 

Neural Ordinary Differential Equations (NeuralODE) とは

Rick T. Q. Chenらによって提案された,「層が連続的」なニューラルネットワークです.概念のインパクトとその有用性によりNeurIPS 2018のBest Paperに選ばれました.

 

NeuralODEの元になっているのはResNetです.ResNetは以下のようにある層の出力\({\bf z}_t\)と次の層の出力\({\bf z}_{t+1}\)の差分\(f\)を学習します.

\[ {\bf z}_{t+1} = {\bf z}_t + f({\bf z}_t,\theta_t).\]

この式が常微分方程式の1次近似であるオイラー法と類似していることから,著者らはメインアイディアである次の関係を導入しました.

\[ \frac{d {\bf z}(t)}{dt} = f({\bf z}(t),t,\theta) \]

つまり,ニューラルネットワークの隠れ状態を時間方向に連続的なものとしてしまい,その時間発展を記述する常微分方程式を学習するのです.これがNeuralODEの肝となります.

 

NeuralODEは既存のニューラルネットワークにレイヤーとして組み込む形で使います.推論時には常微分方程式のソルバ(ODEソルバ)によって上の常微分方程式を一定の時刻\(t\in[t_0,t_1] \)で解くことにより出力を求めます.そして後に述べるように,逆伝搬についても実は常微分方程式を解くだけで計算できます.

 

NeuralODEのメリット

NeuralODEの仕組みは非常にシンプルなものとなっています.しかしながら以下に示すような良い性質を有しているのです.

 

1つ目は推論も逆伝播も常微分方程式で記述される点です.これによりいずれの過程でも中間ノードの値をメモリに保存しておく必要がないため,メモリ効率が優れるという利点があります.また無数の研究がなされている既存のODEソルバを活用できる点も魅力的です.効率的かつ正確に解くことができる上,精度が必要な場合は時間刻みを小さくすることで適応的に精度を高めることができます.

 

2つ目は少ないパラメータで高い精度が出るという点です.論文中の実験ではResNetよりも少ないパラメータで同等の精度が出ることが示されています.理由について論文では「内部状態を連続に時間発展させるならばパラメータも急激には変わらない==時間的に近くのパラメータは自動的に関連付けられる」ためある種の正則化が働いていると説明されています(理論的な根拠は示されていないので半信半疑ですが...).

 

誤差逆伝播

NeuralODEの推論は前述の通りODEソルバを使います.この手法の巧妙な点は学習を進めるためのパラメータの更新方向もODEソルバで効率的に計算できる点にあります.

 

学習のためのコスト関数\(L\)を次のように表現します.

\[ L({\bf z}(t_1)=L\left({\bf z}(t_0)+\int_{t_0}^{t_1}f({\bf z}(t),t,\theta)dt\right)=L({\rm ODESolve}({\bf z}(t_0),f,t_0,t_1,\theta)). \]

 学習を進めるにはこのコスト関数の入力による微分\(\frac{\partial L}{\partial {\bf z}}\)とパラメータ\(\theta\)による微分\(\frac{\partial L}{\partial \theta}\)を求める必要があります.そして実はこれらの量はadjoint sensitivity methodという古くポントリャーギンによって考案された手法により効率的に計算できるのです.

 

adjoint sensitivity methodではまずadjointと呼ばれる量\( {\bf a}(t)=\frac{\partial L}{\partial {\bf z}}(t) \)を求めます.この量のダイナミクス微分の連鎖律のアナロジーとなる次の微分方程式で記述されます*1

\[ \frac{d {\bf a}(t)}{dt}=-{\bf a}(t)^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial {\bf z}} \]

この式を初期条件\(\frac{ \partial L}{\partial {\bf z}}(t_1) \)から時間を遡る方向へ解くことで\( \frac{\partial L}{\partial {\bf z}}(t_0) \)を得ることができます.また\(\frac{\partial L}{\partial \theta}\)についても

\[ \frac{\partial L}{\partial \theta} =\int_{t_1}^{t_0}{\bf a}^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial \theta}dt \]

という積分を解くことで求められます.

 

以上をまとめると,学習に必要な勾配を求めるには初期状態をそれぞれ\( {\bf z}(t_1), \frac{\partial L}{\partial {\bf z}}(t_1), 0\)として連立常微分方程式

\[\begin{align} \frac{d {\bf z}}{d t}&= f({\bf z}(t),t,\theta)\\ \frac{d {\bf a}(t)}{dt}&=-{\bf a}(t)^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial {\bf z}}\\ \frac{d}{dt}\frac{\partial L}{\partial \theta}&={\bf a}^T\frac{\partial f({\bf z}(t),t,\theta)}{\partial \theta} \end{align} \]

を時刻\(t_1 \)から\(t_0 \)で解けばよいということになります.このとき任意のODEソルバを使うことができる上,式中の\({\bf a}^T\)がかかる項は機械学習ライブラリの自動微分によって効率的に計算することができます.

 

実装

 ここからは実際に実装したコードを通じて解説します.実装に用いた言語はPython 3.8.2でフレームワークはPytorch 1.5.0,CUDAのバージョンは10.1です*2

 

ソースコードの全体はhttps://github.com/lizaf999/blog/tree/master/neuralODEにあります.

 

モデル定義

以下のソースコードにモデルのインターフェース部分の実装を示します.モデルの大まかな枠組みとしては,ResNetのようにdown samplingを2回行い,その後本体であるODEBlockをレイヤーとして挿入しています.ODEBlockモジュールではこれまで解説してきたforward / backwardを実装した関数AdjointFunc()に,常微分方程式の右辺\( f({\bf z}(t),t,\theta) \)をモデリングする関数ODEFunc()を投げています.

 

ODEFuncについて眺めてみると,一般的な畳み込みニューラルネットワークとほとんど変わらない構造をしていることがわかります.実際にそのとおりで,関数\( f({\bf z}(t),t,\theta)\)を表現するために入力されてきたデータに時刻\(t\)を持つチャンネルを追加しているだけです.\(t\)に依存する関数の実装方法は色々あると思いますが,もっともシンプルなこの形で十分なのです.

 

続いてODEBlockの中の関数AdjointFuncについて見てみます.この関数は引数として前のレイヤーからの出力と上述のODEFuncのインスタンス積分する時刻\(t_0, t_1\)及びODEFuncのパラメータをとります.最後のODEFuncのパラメータが肝で,前項で見てきたadjoint sensitivity methodを適用するために学習する重みを全て渡す必要があります.このときPytorchのモジュールに備わっているparameters()では扱いづらいリスト形式となるため,flat_paramerts()で一つのtensorにしています.

 

1箇所謎のバグに遭遇したのでそのことを記述しておきます.down samplingの最後となる44行目に1x1の何もしない畳み込み層を挟んでいますが,これはpytorchのエラーを回避するためです.この畳み込み層を挟まないと自動微分が働くなります.公式repositoryのissueでも解決されていないようなので,なんなのでしょう....

from adjoint import AdjointFunc,flat_parameters
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, padding=1,stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=padding)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class ODEFunc(nn.Module):
    def __init__(self,inplanes):
        super(ODEFunc, self).__init__()
        self.conv1 = conv3x3(inplanes+1,inplanes)
        self.norm1 = nn.BatchNorm2d(inplanes)
        self.conv2 = conv3x3(inplanes+1,inplanes)
        self.norm2 = nn.BatchNorm2d(inplanes)

    def forward(self,x,t):
        t_tensor1 = torch.ones((x.size()[0],1,x.size()[2],x.size()[3]),device="cuda")*t
        t_tensor1.requires_grad = False
        x = torch.cat((x,t_tensor1),dim=1)
        x = self.conv1(x)
        x = self.norm1(torch.relu(x))

        t_tensor2 = torch.ones_like(t_tensor1)*t
        t_tensor2.requires_grad = False
        x = torch.cat((x,t_tensor2),dim=1)
        x = self.conv2(x)
        x = self.norm2(torch.relu(x))
        return x

class ODEBlock(nn.Module):
    def __init__(self,inplanes):
        super(ODEBlock,self).__init__()
        self.func = ODEFunc(inplanes)

    def forward(self,x):
        x = AdjointFunc.apply(x,self.func,torch.tensor([0.0],device="cuda"),torch.tensor([1.0],device="cuda"),flat_parameters(self.func.parameters()))
        return x


class Model(nn.Module):
    def __init__(self,num_classes=10):
        super(Model,self).__init__()
        dim = 64
        self.downsampling = nn.Sequential(
            nn.Conv2d(1,dim,5,2,0),nn.BatchNorm2d(dim),nn.ReLU(inplace=True),
            nn.Conv2d(dim,dim,4,2),nn.BatchNorm2d(dim),nn.ReLU(inplace=True),
            conv1x1(dim,dim)#to avoid error
        )

        self.neuralODE = ODEBlock(dim)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self,x):
        x = self.downsampling(x)
        x = self.neuralODE(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
 常微分方程式の組み込み
続いて肝心なforward / backwardを行うAdjointFuncモジュールについて見ていきます.
 
forwardの実装はシンプルです.手法の説明の項で述べたとおり,常微分方程式をソルバで解くだけです.このとき使うソルバは関数odeint()に実装してあるオイラー法あるいは4次のルンゲ=クッタ法です*3.またレイヤーの最終的な出力\({\bf z}_1\)をbackwardの計算のためにコンテキストに保存しています.
 
backwardの実装は技巧的です.元々Pytorchは大抵の場合backwardを実装しなくともよしなに処理してくれますが,NeuralODEではそうはいかないため自前で実装する必要があります.実装するにあたってbackwardの書き方の情報が少なく苦戦しました....関数の大まかな流れとしてはadjoint sensitivity methodのための微分方程式aug_dynamics()を定義してそれを時刻\(t_1\)から\(t_0\)へodeint()で解けばよいということになります.ただ実際にはPytorchの自動微分を活かすために細かな配慮が必要です.というかコードの大半がフレームワークに載せるための補助的なものです.
 
特に厄介なaug_dynamics内の実装を見ていきます.この関数は引数xとして時刻\(t\)における\({\bf z},{\bf a} \frac{d L}{d \theta}\)の値を受け取ります.このとき自動微分を用いるために\({\bf z}\)のプロパティrequires_gradをTrueにする必要がありますが,実はこれだけではPytorchのレポジトリのissuesでも解決されていないバグを踏んでしまうため動きません.そこでソースコードの30行目では暫定的に\({\bf z}\)のデータを初期値とする新しい変数を作ることで問題を回避しています.自動微分を行うautogradを用いて,torch.autograd.Variableを呼ぶのがミソです.
 
また続く30行目でtorch.set_grad_enable(True)を呼ぶ必要があります.明示的にこのプロパティを設定しておかないと,ただbackwardを呼んだだけでは勾配が計算されません.
 
さて32行目では\(f({\bf z}(t),t,\theta) \)の値を評価しています.そして続く33行目では\(f\)の勾配をbackwardで計算していますが,このとき引数として\({\bf a}\)を渡しています.これにより勾配にヤコビアンをかけた値\({\bf a}(t)^T\frac{\partial f}{\partial {\bf z}} \)及び\({\bf a}(t)^T\frac{\partial f}{\partial \theta} \)が計算され各変数のgradに保存されるのです.
 
backwardの入力と出力についても触れておきましょう.この関数は入力としてコンテキストctxと逆伝播してきた前の層の勾配\(\frac{d L}{d {\bf z_1}}\)を取ります.コンテキストとはforwardから必要な情報を伝えるための変数です.backwardの出力は誤差関数をforwardの入力のうちコンテキストを除いた変数について微分したものとなります.つまりforwardの引数の順に合わせて,誤差\(L\)を微分した値を並べて返す必要があります.今回必要な勾配は\(\frac{dL}{d{\bf z}_0}\)と\(\frac{dL}{d\theta}\)のみなので,それ以外の要素についてはNoneをおいて数を合わせています.
 
ちなみにコードには一見おかしな箇所があります.学習する重みである\(\theta\)周りです.forwardの入力としてはモデルのパラメータを一つのtensorにしたtheta_flattenをとっています.しかしこの変数はforward中で使われません.またbackwardでは\(\theta\)をtheta_listとして改めて取得しています(23行目).なぜこのように入り組んだ構造になっているのでしょうか.その理由はフレームワークの仕様にあります.前述の通り学習に必要な勾配をbackwardが返せるようにforwardの引数を合わせる必要があるのです.したがってforwardでは使わない変数でも入れておく必要があります*4
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as functional

class AdjointFunc(torch.autograd.Function):
    def __init__(self):
        super(AdjointFunc,self).__init__()

    @staticmethod
    def forward(ctx,z0,func,t0,t1,theta_flatten):
        z1 = odeint(z0,func,t0,t1)
        ctx.func = func
        ctx.t0 = t0
        ctx.t1 = t1
        ctx.save_for_backward(z1.clone())
        return z1
    
    @staticmethod
    def backward(ctx,dLdz1):
        func = ctx.func
        t0 = ctx.t0
        t1 = ctx.t1
        z1= ctx.saved_tensors[0]

        theta_list = list(func.parameters())
        
        s0 = [z1.clone(),dLdz1.clone()]

        for theta_i in theta_list:
            s0.append(torch.zeros_like(theta_i))

        def aug_dynamics(x,t):
            z,a,dfdth_unused = x[0],x[1],x[2]
            z = torch.autograd.Variable(z.data,requires_grad=True)
            torch.set_grad_enabled(True)#important

            f = func(z,t)
            f.backward(a)

            adfdz = z.grad
            #z.grad.zero_()#不要

            adfdth_list = []
            for theta_i in theta_list:
                adfdth_list.append(-theta_i.grad)
                theta_i.grad.zero_()
            
            return (f, -adfdz, *adfdth_list)

        rlt = odeint(s0,aug_dynamics,t1,t0)
        #z0 = rlt[0]
        dLdz0 = rlt[1]
        dLdth0 = []
        for i in range(2,len(rlt)):
            dLdth0.append(rlt[i])
        dLdth0 = flat_parameters(dLdth0)

        return dLdz0,None,None,None,dLdth0

実験

それでは実装したモデルを用いて実験をしてみます.論文中にはいくつかの実験が記されていますが,ここではNeuralODEのコアの部分が実装できているか確認するために教師あり学習のタスクを選びます.
 
論文ではMNISTの分類問題についてResNetとの比較を行っています.この実験で確かめたいのは精度とパラメータ数の関係で,論文中のTable 1には以下のように記されています.
モデル テスト誤差 パラメータ数
ResNet 0.41% 0.60M
NeuralODE 0.42% 0.22M
表から,NeuralODEはResNetより少ないパラメータで同等の精度が出ることを主張しています.
 
実験した結果は次の表のようになりました.論文中には各層のユニット数などは書いていないので,実装に当たり適当な値で代替しています.実験用のソースコードは上述のgithubにおいてあります.またResNetとしてはpytorch/visionにある実装を改変して用いました.
モデル テスト誤差 パラメータ数
ResNet 0.88% 0.15M
NeuralODE 0.82% 0.14M
実験の結果,NeralODEはResNetと同じぐらいのパラメータ数で同程度の精度を示しました.つまり論文のデータとはずれた結果となりました.しかし実装の本質的な部分は問題ないと考えています.以下結果についての考察です.
 
まずパラメータ数の違いについてです.モデルの詳細が記されていない以上パラメータ数を論文と完全に一致させることは出来ないので,実験ではNeuralODEのパラメータ数を論文とだいたい合わせる方針を取りました.ResNetのパラメータ数はNeuralODEとだいたい合わせています.そしてResNetとNeuralODEの精度がほぼ同等という実験結果が出ました.このことから,実装したNeuralODEのコアの部分はきちんと働いておりResNetの代わりになっていると推測します.また,論文ではResNetよりも少ないパラメータ数で同等の精度が出ると主張していますが,それは根拠に欠ける主張です.本当に必要なのは今回やったような同程度のパラメータ数で精度に違いが出るかという実験だと思います.つまり論文でやっている実験ではResNetのパラメータ数が0.2M程度のときの精度が不明な点が問題だと考えています.
 
つづいて精度についてです.これはハイパーパラメータのチューニング次第だと考えられます.もともと誤差1%を切ってますので,MNISTでこれ以上の精度勝負をすることは本質的ではないと考え今回はそこまでチューニングしていません.調整の余地としては,層の数やノード数,kernelのサイズ等に加えて,NeuralODEに特有であるODEソルバのハイパーパラメータが考えられます.
 
実験結果をまとめると,NeuralODEは常微分方程式に基づいてあり全く違う仕組みでありながらも,ResNetに比べて少なくとも同等の性能が出ていると言えます.このときNeuralODEのforward / backwardでは中間層の値を保持する必要がありませんから,大変メモリ効率が良いモデルです.
 

発展

NeuralODEは様々な発展が考えられているポテンシャルのあるモデルです.
 
今回取り上げた分類問題の他にも,normalizing flowを代替したり時系列データのダイナミクスをモデル内部に再現したりできます.
 
また7月中旬に開催されていたICML2020でもNormalizing FlowのワークショップでNeuralODEを拡張した手法がいくつも提案されていました.
 

所感とまとめ

常微分方程式によって特徴量を計算できる点が非常に面白いと感じました.連続的な力学系によって時間発展させることにより計算が進むという性質は,natural computingと呼ばれるような,半導体以外の材料で作り上げる計算機の発展に寄与すると思います.このブログの記事の傾向に現れていますが,私は情報や計算と物理を結ぶ分野が好きなので,NeuralODEは見逃せません.近年注目の量子情報あるいは量子コンピュータでは物理的な変換によって計算が定義されます.なので力学系と計算の理解が量子コンピュータの研究に繋がっていくと面白いですね.
 
個人的に気になるのは,やはりディープラーニングの理論的な解析が不十分な点です.NeuralODEも例外ではなく,なぜうまくいくのか定性的なアイディアは書かれていても本質的な原理は不明です.例えばResNetはデータがいくつもの経路を通るアンサンブル学習を行っているという論文があります*5.しかしResNetの一般化であるNeuralODEには同様の解析は通用しない点で,そのような解析は未だ解釈に過ぎないと思います.やはりディープラーニングについてはまだまだ分かっていないことばかりです.そしてそのような曖昧な状況で精度勝負を行いState-of-the-Artを目指す研究の流れには強い危機感を覚えます.
 
また,今回初めてディープラーニングを実装してみましたが,全体的にフレームワークに載せるのが厄介だと感じました.GPUを適切に活用したり複雑なモデルを構築するためにフレームワークは非常に強い力を発揮してくれますが,どうしてもフレームワークの(本質的ではない)仕様に時間をとられがちです.
 
最後に,論文の数式に簡単な間違いがあるように思えるのでここで指摘しておきます.誤差逆伝播のためのadjoint sensitive methodで\(\frac{dL}{d\theta}\)を求めるときに論文中の式(5)の符号が間違っています.論文の次ページに乗っているアルゴリズムではODEソルバによって積分
\[\int_{t_1}^{t_0}-{\bf a}(t)^T\frac{\partial f}{\partial \theta}dt\]
を解きますが,式(5)では
\[\int_{t_1}^{t_0}{\bf a}(t)^T\frac{\partial f}{\partial \theta}dt\]
 となっています.他の量の符号や実装してみた結果を踏まえると前者が正しいはずです.このあたり積分の方向がややこしいため表記ミスをしたのだと思いますが,論文中のAlgorithm 1には正しいアルゴリズムが記されているので問題なかったのでしょう.単純に私が勘違いしてるだけの可能性も依然としてあり得ますが....

 

*1:導出は自明ではありません.詳細は論文の付録に記載されています.

*2:PytorchとCUDAのバージョンによっては一部の自動微分の実装でエラーが出て実行できませんでした.

*3:scipyのodeintではありません.GPUを使うためにtorch.tensorを使って自前で実装する必要があります

*4:本当はコンテキストctx経由でbackwardに\(\theta\)を伝えたいところですがうまく動きませんでした.

*5:Veit et al.. Residual networks behave like ensembles of relatively shallow networks, NeurIPS 2016.