ihit's diary

ちょっとしたメモに

K-means

とある事情でK-meansを実装してみた。

大まかな流れとしては
1. 適当にK個のクラスに割り振る
2. それぞれのクラスの重心を代表点とする
3. それぞれの点に最も近い代表点を割り出しクラス分けする
4. 2と3を代表点が動かなくなるまでやる

以下ソース

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt

#代表点の個数
K = 3

#画像のサイズ
LEN = 512
LENGTH = LEN*LEN

#RGB
DIM = 3

#LENGTHのRGB画像
im = cv2.imread('lena.bmp',flags=1)
if im is None:
	print "None"
	exit()
im.resize((LENGTH,DIM))

#代表点と画素値からそれぞれのクラスタ番号を得る
def get_len(repre,x):
	length = sys.maxint #代表点までの距離
	for i in range(0,len(repre)):
		if np.linalg.norm(repre[i]-x) < length:
			length = np.linalg.norm(repre[i]-x)
			num = i
	return num

#重心を計算し、新たな代表点を計算
def get_new(im,num):
	repre = np.zeros(K*DIM).reshape((K,DIM))
	count = np.zeros(K)
	for i in range(0,LENGTH):
			repre[num[i]] += im[i]
			count[num[i]] += 1
	for i in range(0,K):
		if count[i]!= 0:
			repre[i] /= count[i]
		else:
			pass
	return repre

#代表点の誤差を計算
def get_error(repre,new_repre):
	error = 0
	for i in range(0,len(repre)):
		error += np.linalg.norm(repre[i]-new_repre[i])
	return error

#初期クラス
num = np.zeros(LENGTH)
for i in range(0,LENGTH):
		num[i] = np.random.randint(0,K)

#現在代表点と誤差の初期化
repre = np.zeros(K*DIM).reshape((K,DIM))
new_repre = np.zeros(K*DIM).reshape((K,DIM))
error = get_error(repre,new_repre)

#k-meansの処理
while 1:
	#クラスタ番号更新
	repre = get_new(im,num)
	error = get_error(repre,new_repre)
	new_repre = repre
	print "error=",error
	if error < 0.1:
		break
	for i in range(0,LENGTH):
			num[i] = get_len(repre,im[i])

#符号なし8ビット整数を使う
im_kmean = np.zeros(LENGTH*DIM, dtype=np.uint8).reshape((LENGTH,DIM))
for i in range(0,LENGTH):
		im_kmean[i][:] = repre[num[i]]

im_kmean.resize((LEN,LEN,DIM))
cv2.imshow("test",im_kmean)
cv2.imwrite('result.bmp', im_kmean)
cv2.waitKey(0)
cv2.destroyAllWindows()

結果は
入力画像
f:id:ihit:20140404023745j:plain
出力画像(3色に減色)
f:id:ihit:20140404023801j:plain

まあちゃんと動いているみたいだけど遅い…
手順3のそれぞれの点から一番近い代表点を割り出す所がとにかく遅い。
これを解消するには木構造を使用したりすればいいらしい。
他にも早くするために最初の代表点を上手く選んでやるなんて方法もあるとのこと。