お手軽線形補間クラスの実装

うちの研究室では組み込み開発の占める割合がすごく大きいので,自分も含めてメンバはSTLだったりboostだったりC++0xだったりに馴染んでいない.それが研究室全体の生産性を妨げている部分があるなーと感じていたので,修論が終わってから研究室内でC++講座的なものを開催している.そこで題材にした簡単な線形補間クラスのことをせっかくだから日記のネタに.

お題

以下のように書ける,線形補間クラスが欲しい.補外は考えない.スレッドセーフも考えないし,桁落ちもとりあえず考えない.ただし,なるべく色んな型に対して使えるジェネリックなクラスにしたい.

int main(void){
    LI<double, double> l;
 
    // (x,y)を登録
    l.append(0, 0.5);
    l.append(1, 1);

    // 登録済みのkeyに対しては対応するvalueを返す
    std::cout << l[0] << std::endl; // "0.5"
    std::cout << l[1] << std::endl; // "1"
    // 登録されていないものは線形補間の結果を返す
    std::cout << l[0.5] << std::endl; // "0.75"
}

STLを知っていると,operator[]の呼び出しのうち,最初の2つはmapの動作そのものであることにすぐ気づく.最後の1つだけがmapと違う動作で,これをどう表現するかという問題*1

書いたもの

やってることはシンプルだけど,mapやSTLアルゴリズムを知らずに全部自力で実装するのは大変.STLC++0xを使うとこんな感じに短く書ける.

#include<assert.h>
#include<map>
#include<algorithm>
#include<functional>

template<typename X = double, typename Y = double>
class LI {
public:
    typedef std::map<X, Y> Points;

    LI(){}

    template<class Iter>
    LI(Iter first, Iter last) : points_(first, last) {}

    void append(X x, Y y){
        points_[x] = y;
    }

    Y operator[] (X x) const{
        assert(!points_.empty());
        //xがinsert済みのものと一致していればその値を返し,そうでなければ補間計算の結果を返す
        return points_.find(x) == points_.end() ? interpolate_(x) : points_.at(x);
    }

private:
    static bool greater_(typename Points::value_type p, X x){
        return p.first > x ? true : false;
    }

    Y interpolate_(X x) const{
        //xを超える最小のpointを指すイテレータ
        auto it_h = std::find_if(points_.begin(), points_.end(), 
                                 std::bind2nd(std::ptr_fun(greater_), x));

        //補外が必要になるようなケースはassertで弾く
        assert(it_h != points_.begin() && it_h != points_.end());

        //xを下回る最大のpointを指すイテレータ
        auto it_l = it_h;
        it_l--;

        return linear_((*it_l).first, (*it_l).second, (*it_h).first, (*it_h).second, x);
    }

    Y linear_(X x1, Y y1, X x2, Y y2, X x) const{
        return (x - x1) * (y2 - y1) / (x2 - x1) + y1;
    }

    Points points_;
};

ここでは補間結果のキャッシュは取らず,逐一計算を実行している.std::bind2nd(std::ptr_fun(greater_), x)とか,greater_の引数が非対称だったりするのがややキモいけど,他は素直な実装.STL無しで同等の安全性を持ったクラスを書くのは相当に大変だろう.
mapのメンバ関数atはC++0xから追加されたもの.operator[]と比べると

  • 渡されたKeyに対応する要素が無い場合,例外を投げる
  • constメンバである

という2点が異なる.ここではstd::findの後ろで呼び出すために例外が投げられない*2こと,「constが使えるときには必ず使え」の鉄則から,atを使用してLI::operator[]をconstにしている.

なるべくジェネリック

ところで,わざわざstaticメンバ関数greater_なんてものを定義しなくても,STLの関数アダプタstd::greaterを使ったほうがクラスがすっきり書ける.

#include<assert.h>
#include<map>
#include<algorithm>
#include<functional>

template<typename X = double, typename Y = double>
class LI {
public:
    typedef std::map<X, Y> Points;

    LI(){}

    template<class Iter>
    LI(Iter first, Iter last) : points_(first, last) {}

    void append(X x, Y y){
        points_[x] = y;
    }

    Y operator[] (X x) const{
        assert(!points_.empty());
        //xがinsert済みのものと一致していればその値を返し,そうでなければ補間計算の結果を返す
        return points_.find(x) == points_.end() ? interpolate_(x) : points_.at(x);
    }

private:
    Y interpolate_(X x) const{
        //xを超える最小のpointを指すイテレータ
        auto it_h = std::find_if(points_.begin(), points_.end(), 
                                 std::bind2nd(std::greater<Points::value_type>(), std::make_pair(x, Y())));

        //補外が必要になるようなケースはassertで弾く
        assert(it_h != points_.begin() && it_h != points_.end());

        //xを下回る最大のpointを指すイテレータ
        auto it_l = it_h;
        it_l--;

        return linear_((*it_l).first, (*it_l).second, (*it_h).first, (*it_h).second, x);
    }

    Y linear_(X x1, Y y1, X x2, Y y2, X x) const{
        return (x - x1) * (y2 - y1) / (x2 - x1) + y1;
    }

    Points points_;
};

これでも冒頭に書いたアプリケーションコードは全く問題なく動く.ところが,このバージョンは型Yに対して新たに次の制約を課す.

  • デフォルトコンストラクタが存在するか,組み込み型であること
  • operator<(const Y&,cosnt Y&)が定義されていること

std::greater >はpairのoperator<を呼び出すが,これはKeyだけでなくValueのoperator<も要求する.たとえばVCでは次のように実装されており,Y型のsecondについての比較が定義されていなければならない.

template<class _Ty1,
    class _Ty2> inline
    bool operator<(const pair<_Ty1, _Ty2>& _Left,
        const pair<_Ty1, _Ty2>& _Right)
    {	// test if _Left < _Right for pairs
    return (_Left.first < _Right.first ||
        !(_Right.first < _Left.first) && _Left.second < _Right.second);
    }

これに対して最初のバージョンでは,直感的に線形補間に必要だと思われるoperatorだけが備わっていれば良い.たとえばstd::complexや,行列ライブラリEigenのMatrixインスタンスに対しても補間が行える.

#include<iostream>
#include<Eigen/Core>
#include "LI.h"

int main(void){
    Eigen::Vector3d v1(1, 2, 3);
    Eigen::Vector3d v2(2, 4, 6);

    LI<double, Eigen::Vector3d> l;
    l.append(0, v1);
    l.append(1, v2);

    //Vector3dに対するoperator<が無いため,2番目のバージョンではコンパイルが通らない
    std::cout << l[0] << std::endl;// "1 2 3"
    std::cout << l[1] << std::endl;// "2 4 6"
    std::cout << l[0.5] << std::endl;// "1.5 3 4.5"
}

ポリシー・クラス

作成した線形補間クラスのpublicインターフェースは,appendで値のペアを登録し,operatorで補間込みの値を得るというだけだった.これは線形に限らず他の補間アルゴリズムでも共通に使える.そこで,線形補間クラスを一般化した補間クラスを設計することを考える.これはさっきのコードの一部を別のクラスに括り出すだけで簡単に実現する.

template<class X, class Y>
class Linear {
public:
    Y interpolate(X x, const std::map<X,Y>& points) const{
        auto it_h = std::find_if(points.begin(), points.end(),   
                                 std::bind2nd(std::ptr_fun(greater_), x));

        assert(it_h != points.begin() && it_h != points.end());

        auto it_l = it_h;
        it_l--;

        return linear_((*it_l).first, (*it_l).second, (*it_h).first, (*it_h).second, x);
    }

private:
    static bool greater_(typename std::map<X,Y>::value_type p, X x){
        return p.first > x ? true : false;
    }

    Y linear_(X x1, Y y1, X x2, Y y2, X x) const{
        return (x - x1) * (y2 - y1) / (x2 - x1) + y1;
    }
};

template<typename X = double, typename Y = double, template<class,class>class InterpolatingPolicy = Linear>
class Interpolation : public InterpolatingPolicy<X,Y>{
public:
    typedef std::map<X, Y> Points;

    Interpolation(){}

    template<class Iter>
    Interpolation(Iter first, Iter last) : points_(first, last) {}

    void append(X x, Y y){
        points_[x] = y;
    }

    Y operator[] (X x) const{
        assert(!points_.empty());
        //xがinsert済みのものと一致していればその値を返し,そうでなければ補間計算の結果を返す
        return points_.find(x) == points_.end() ? interpolate(x, points_) : points_.at(x);
    }

private:
    Points points_;
};

Interpolationクラスは第三引数を省略すればこれまで通り線形補間を実行する.必要であれば自分でinterpolate関数を持った補間ポリシークラスを実装し,引数に与えてやれば良い.このような外部からアルゴリズムを付与するクラスをポリシー・クラスと呼ぶ.

まとめ

STLを知っておくだけで,再利用性の高い線形補間クラスを50行に満たない簡単な実装で得ることができた.ふつうSTLの内部実装を意識する必要は無いが,今回のようにSTLアルゴリズムが暗黙に要求するインターフェースを理解しておくと,ジェネリックな設計に役立つこともある.

*1:mapの場合はValue型のデフォルト値を返す

*2:マルチスレッドを考えだすとややこしい