思考ノイズ

無い知恵を絞りだす。無理はしない。

<はじパタ> 10.3 階層型クラスタリング (融合型)

K-平均法(10.2)を扱った前回に引き続きまして、クラスタリングの10章。今回は階層型、融合型というような表現がされている分類方法についてみていきたいと思います。

階層型クラスタリング

この融合型の方法は「定義」により距離的に近いとされる2つのクラスを一つのクラスとして融合することを繰り返していきます。今回は引き続きアヤメのデータを使いましたが、この場合はクラスは3つとなりますので、融合をすすめていき、最終的に3つのクラスに収束したら終わりとしました。

融合するための距離の4つの「定義」

融合にはクラス間の距離的な近さを計測します。はじパタでは距離の定義について以下の4つを紹介しています。

  • 単連結法

2クラスに属する要素通しで最も近いもの通しをそのクラス間の近さと定義して、その値が小さいもの通しを融合していく方法

  • 完全連結法

2クラスに属する要素通しで最も遠いもの通しをそのクラス間の近さと定義して、その値が小さいもの通しを融合していく方法

  • 群平均法

クラスタ内のすべての要素の平均をとり、各クラスの平均値間の距離をクラスの近さと定義する。その値が小さいもの通しを融合していく方法

  • ウォード法 クラスタを融合した場合の平均値が、もとのそれぞれのクラスタ内の平均値から変化する量をクラスタの近さと定義する。その値の小さいもの通しを融合していく方法

コーディングの際の考慮すべき点

アヤメデータは150の要素から成り立つのでそれぞれの要素の距離を示す行列を持つ必要があり、150 x 150のマトリックスとなります。さらにそのうち、上記の定義により近いと判断された2点が同じクラスタとされていきます。要素に番号を振り、融合された要素をリスト化していく形で、マトリックスが縮小(=クラスタ(リスト)内の要素が増えていく)ようにして、最終的に3つのクラスタになった時点で止めるようにしました。 最小値をもとめるコードの大部分は使いまわしが可能で、変更すべき点は、マトリックスのマージのプロセスのみとなりました。

実行結果

アヤメデータ4つの特性をそれぞれ2つずつ、計6通りのセットで確認しています。評価については10.2のK-平均法と同じく正解率とF値を出すようにしています。詳細はこちら

 ===  sepal length (cm)  ===  sepal width (cm)  === 
 ***  Single Linkage Method
   * correctness --  0.35333333333333333
   * F-Value     --  0.502442878626
 ***  Complete Linkage Method
   * correctness --  0.52
   * F-Value     --  0.518563070642
 ***  Group Average Method
   * correctness --  0.7133333333333334
   * F-Value     --  0.741125961406
 ***  Ward Method
   * correctness --  0.78
   * F-Value     --  0.73182865776
 ===  sepal length (cm)  ===  petal length (cm)  === 
 ***  Single Linkage Method
   * correctness --  0.6733333333333333
   * F-Value     --  0.763119756663
 ***  Complete Linkage Method
   * correctness --  0.84
   * F-Value     --  0.825008207203
 ***  Group Average Method
   * correctness --  0.7466666666666667
   * F-Value     --  0.758906000869
 ***  Ward Method
   * correctness --  0.84
   * F-Value     --  0.825008207203
 ===  sepal length (cm)  ===  petal width (cm)  === 
 ***  Single Linkage Method
   * correctness --  0.34
   * F-Value     --  0.501625617698
 ***  Complete Linkage Method
   * correctness --  0.5866666666666667
   * F-Value     --  0.648999234975
 ***  Group Average Method
   * correctness --  0.7066666666666667
   * F-Value     --  0.745888252968
 ***  Ward Method
   * correctness --  0.8266666666666667
   * F-Value     --  0.806976194904
 ===  sepal width (cm)  ===  petal length (cm)  === 
 ***  Single Linkage Method
   * correctness --  0.66
   * F-Value     --  0.760564623487
 ***  Complete Linkage Method
   * correctness --  0.94
   * F-Value     --  0.941265889017
 ***  Group Average Method
   * correctness --  0.8733333333333333
   * F-Value     --  0.863216837848
 ***  Ward Method
   * correctness --  0.8666666666666667
   * F-Value     --  0.854987973338
 ===  sepal width (cm)  ===  petal width (cm)  === 
 ***  Single Linkage Method
   * correctness --  0.66
   * F-Value     --  0.760564623487
 ***  Complete Linkage Method
   * correctness --  0.8066666666666666
   * F-Value     --  0.738999798162
 ***  Group Average Method
   * correctness --  0.66
   * F-Value     --  0.760564623487
 ***  Ward Method
   * correctness --  0.84
   * F-Value     --  0.82020552181
 ===  petal length (cm)  ===  petal width (cm)  === 
 ***  Single Linkage Method
   * correctness --  0.6733333333333333
   * F-Value     --  0.763119756663
 ***  Complete Linkage Method
   * correctness --  0.96
   * F-Value     --  0.960162654599
 ***  Group Average Method
   * correctness --  0.96
   * F-Value     --  0.960162654599
 ***  Ward Method
   * correctness --  0.8933333333333333
   * F-Value     --  0.887000529608

考察

全体的な傾向としては単連結法が比較して正解率が良くなく、ほかの方法はケースバイケースでしょうか。そのなかでもウォード法は安定して高い値がでているように見えます。

K-平均法の時と同じようにすべての特性を扱った4次元データでの精度を確認してみましたが、やはり特徴点が顕著となる2点を使ったデータのほうが精度が高いという結果になりました。

また、K-平均法のほうが精度は高いです。結構苦労してコーディングしたのに悲しい限り。アヤメデータ以外でのK-平均法よりよくなるケースがあるのか気になるところではあります。

ソースコード

いかJupyter notebook上で動くソースです。

from sklearn import datasets
import matplotlib.pyplot as plt
from collections import Counter
#inline pyplot

class ClastaringFusion():
    def __init__(self):
        self.x = []
        self.y = []
        self.p = []
        self.t = []
        self.pt = []
        self.num = [50, 50, 50]
        iris = datasets.load_iris()
        
        self.feature_names = iris.feature_names
        features = iris.data
        targets = iris.target
        self.dmatrix = np.array([])

    def install_val(self,a=1,b=2):
        self.p = [a, b]
        self.x = features.T[self.p[0]]
        self.y = features.T[self.p[1]]
        self.t = targets
        self.distance()

        
    def distance(self, x=None, y=None):
        if x==None: x = self.x

        self.dmatrix = np.array([[np.sqrt((self.x[i] - self.x[j])**2 + (self.y[i] - self.y[j])**2 ) for i in range(150)] for j in range(150)]) 
        
        #return np.array([[np.sqrt((i - j)**2 + (i - j)**2) for i in x] for j in y])          
    def comp_fusion(self, list1, list2, fm=0):
        def ave(lst):
            x_lst = [self.x[i] for i in lst]
            y_lst = [self.y[i] for i in lst]

            m_x = sum(x_lst)/len(x_lst)
            m_y = sum(y_lst)/len(y_lst)


            return sum([(x_lst[i] - m_x)**2 + (y_lst[i] - m_y)**2 for i in range(len(lst))])
            
        if fm ==0:
            v = min([min([self.dmatrix[l1, l2] for l2 in list2]) for l1 in list1])
        elif fm==1:
            v =  max([max([self.dmatrix[l1, l2] for l2 in list2]) for l1 in list1])
        elif fm==2:
            v = sum([sum([self.dmatrix[l1, l2] for l2 in list2]) for l1 in list1]) /(len(list1) * len(list2))
        elif fm==3:
            v = ave(list1+list2) - (ave(list1) + ave(list2))    
        else:
            print("Error: Invalid number for Fusion Method")
            sys.exit(1)
            
        return v
    
    def fusion(self, fm=0, tree=None):
        '''
        fm (fusion_method)
        0: Single Linkage Method
        1: Complete Linkage Method
        2: Group Average Method
        3: Ward Method
        
        '''
        if tree==None:
            tree = [[i] for i in range(150)]
        dmtx = np.array([[self.comp_fusion(i, j, fm) for i in tree] for j in tree])

        for i in range(dmtx.shape[0]):
            dmtx[i, i] = 100
        
        val = np.argmin(dmtx)
        #print(val)
        #print(dmtx.shape)
        #print("D", dmtx[val//dmtx.shape[0], val%dmtx.shape[0]])
        minval = np.array(np.where(dmtx==dmtx[val//dmtx.shape[0], val%dmtx.shape[0]]))

        ctree = []
        checked = []
        lgth = minval.shape[1]
        #print(minval)
        #print("lgth: ", lgth)
        for i in range(lgth): 
            flag = True
            one = tree[minval[0][i]]
            two = tree[minval[1][i]]
            checked += (one + two)
            inc_num = []
            
            #print("one+two", one, two)
            #print(" ctree ", ctree)
            if minval[0][i] != minval[1][i]:

                for j, t in enumerate(ctree):
                    if set(t) & set(one) or set(t) & set(two):
                        inc_num.append(j)
                #print(" ", inc_num, " ==== ", one, two)
                #print(" ", ctree)
                if len(inc_num) == 0:
                    ctree.append(one+two)
                else:
                    ctree[inc_num[0]] = list(set(one+two+ctree[inc_num[0]]))
                    if len(inc_num) > 1:
                        for n in inc_num:
                            ctree[inc_num[0]] = list(set(ctree[inc_num[0]]+ctree[n]))
                            
                        for k, n in enumerate(inc_num[::-1]):
                            if k == len(inc_num)-1:
                                break
                            ctree.pop(n)
                    
          
        for i in tree:
            if not set(i) & set(checked):
                ctree.append(i)
        #print(ctree)
        test = []

        for t in ctree:
            test += t
        
        missing =  (set(range(150)) - set(test))
        tomatch =  (set(test) - set(range(150)))
        #print("Missing", missing)
        #print("Too Match", tomatch)
        return ctree  
    def show_all_result(self):
        fm_string = [
            "Single Linkage Method",
            "Complete Linkage Method",
            "Group Average Method",
            "Ward Method"
        ]
        for i, (a, b) in enumerate(itertools.combinations(range(4), 2)):
            
            print(" === ", self.feature_names[a], " === ", self.feature_names[b], " === ")
            self.install_val(a, b)
            #self.plot_map()

            for fm in range(4):
                tree = cf.fusion(fm)
                c = 0
                while len(tree) > 3:
                    #print("======")
                    tree = cf.fusion(fm, tree)
                    #print(len(tree))
                    #if len(tree) < 4:
                    #    break

                    c += 1
                crct, fv = cf.correct_rate(tree, False)
                print(" *** ", fm_string[fm])
                print("   * correctness -- ", crct)
                print("   * F-Value     -- ", fv)
             
        
    def correct_rate(self, tree, verbose=True):

        trees = []
        crcts = 0.0
        fvs = 0.0
        

        for p in list(itertools.permutations(range(3))):

            pt_tree = []
            for i in range(150):
                for j in range(3):
                    if i in tree[j]:
                        pt_tree.append(p.index(j))
                        break
                        
            trees.append(pt_tree)
            #print("tree", pt_tree)
                
            cnt = np.array([[Counter(pt_tree[i*50:i*50+50])[j] for i in range(3)] for j in range(3)])
            odr = [np.argmax([Counter(pt_tree[j*50:j*50+50])[i] for i in range(3)]) for j in range(3)]
            crct = []
            crctnum = 0

            for i in range(150):
                if i // 50 == pt_tree[i]:
                    crctnum += 1

            fv = 1.0

            for i in range(3):

                x = np.count_nonzero(pt_tree[i*50:i*50+50] == odr[i])
                z = np.count_nonzero(pt_tree[(i+1)%3*50:(i+1)%3*50+50] == odr[i]) + \
                    np.count_nonzero(pt_tree[(i+2)%3*50:(i+2)%3*50+50] == odr[i])
                rec = x/50
                pre = 50/(50+z)
                fv *= 2*rec*pre/(rec+pre)

            if crctnum > crcts:
                fvs = fv
                crcts = crctnum
            #print("crct", crctnum)
        if verbose:
            print("CORRECTNESS: ", crcts/150)
            print("F-Value    : ", np.cbrt(fvs))
        return crcts/150, np.cbrt(fvs)
        
cf = ClastaringFusion()
cf.install_val(1, 2)
cf.distance()
cf.show_all_result()

関連記事

bython-chogo.hatenablog.com

bython-chogo.hatenablog.com