無限混合正規分布 cont.

前回の補足。

(たぶん)実装可能なメモ書きがこちら

擬似コード(pseudocode)代わりのRコード。
データ(1行1観測の N×D 数値行列 X)を入れれば一応動くが、ナイーブな実装なので(R であることを差し引いても)非常に遅い。

# 1. prepare dataset
# Replace the placeholder with an N x D numeric matrix (one observation per row).
X = ...
D = ncol(X)   # dimensionality; the hyperparameters below depend on it

# 2. set up hyperparameters of the Normal-inverse-Wishart prior
theta0 = list()
theta0$m = numeric(D)        # prior mean (origin), sized to the data instead of fixed 2-D
theta0$xi = 1e-2             # pseudo-count weight of the prior mean in the posterior update
theta0$eta = 2*D+2+1e-1      # must exceed 2*D so the multivariate-gamma arguments stay positive
theta0$B = 1e-1*diag(D)      # prior scale matrix (D x D)

# 3. Gibbs sampling
#
# Collapsed Gibbs sampler for a Dirichlet-process mixture of Gaussians with a
# conjugate Normal-inverse-Wishart (NIW) prior. Component means/covariances are
# integrated out analytically, so only the cluster labels are sampled.
#
# Arguments:
#   X      : N x D numeric matrix, one observation per row (coerced with as.matrix).
#   theta0 : list of NIW hyperparameters
#              m   - prior mean, length D
#              xi  - pseudo-count weight of the prior mean (> 0)
#              eta - the code uses 0.5 * (eta - D - 1) as the multivariate-gamma
#                    argument, so eta should exceed 2 * D
#              B   - D x D positive-definite scale matrix
#   alpha  : DP concentration parameter (default 0.1, the previously hard-coded value)
#   itmax  : number of full sweeps over the data (default 500, as before)
#
# Returns a list:
#   Z : integer labels in 1..K for every row of X, from the last sweep
#   K : number of occupied clusters (labels are kept contiguous)
DPM = function(X, theta0, alpha = 1e-1, itmax = 500){

	X = as.matrix(X)
	N = nrow(X)
	D = ncol(X)
	stopifnot(length(theta0$m) == D, all(dim(theta0$B) == c(D, D)))

	# 3.1 functions

	# logarithm of the multivariate gamma function Gamma_d(x)
	ln_mgamma = function(x, d){
		0.25 * d * (d - 1) * log(pi) + sum(lgamma(x + 0.5 * (1 - seq_len(d))))
	}

	# log |det(B)| as a plain number (determinant() attaches a "logarithm" attribute)
	ln_det = function(B){
		as.numeric(determinant(B, logarithm = TRUE)$modulus)
	}

	# Treat a bare vector as a single observation (1 x D matrix).
	# is.null(dim()) is used instead of class(X) == "numeric": for a matrix,
	# class() has length 2 in R >= 4.0, which makes if() an error in R >= 4.2.
	as_rows = function(X){
		if(is.null(dim(X))) matrix(X, nrow = 1) else X
	}

	# NIW posterior hyperparameters after observing the rows of X
	posterior = function(X, theta0){
		X = as_rows(X)
		N = nrow(X)

		theta = list()
		theta$xi = theta0$xi + N
		theta$m = (theta0$xi * theta0$m + colSums(X)) / theta$xi
		theta$eta = theta0$eta + N
		theta$B = theta0$B + theta0$xi * tcrossprod(theta0$m) + crossprod(X) -
			theta$xi * tcrossprod(theta$m)
		theta
	}

	# log marginal likelihood of the rows of X when the Gaussian parameters
	# follow an NIW distribution with hyperparameters theta0
	ln_marginal_likelihood = function(X, theta0){
		X = as_rows(X)
		N = nrow(X)
		D = ncol(X)

		theta = posterior(X, theta0)
		a = 0.5 * (theta0$eta - D - 1)
		ap = 0.5 * (theta$eta - D - 1)

		-0.5 * D * N * log(pi) + 0.5 * D * log(theta0$xi / theta$xi) +
			ln_mgamma(ap, D) - ln_mgamma(a, D) +
			a * ln_det(theta0$B) - ap * ln_det(theta$B)
	}

	# 3-2. initialize MCMC state: label 0 means "not yet assigned"
	Z = integer(N)
	K = 0

	# 3-3. execute gibbs sampling
	for(it in seq_len(itmax)){

		# visit the samples in random order
		for(n in sample.int(N)){

			# Take x_n out of its cluster. If that empties the cluster, drop it
			# and shift higher labels down so labels stay 1..K and no time is
			# spent on dead clusters.
			old = Z[n]
			Z[n] = 0L
			if(old > 0 && !any(Z == old)){
				shift = Z > old
				Z[shift] = Z[shift] - 1L
				K = K - 1
			}
			x = X[n, , drop = FALSE]

			# unnormalized log probability of joining each existing cluster
			# (CRP weight times posterior predictive) or opening a new one
			lnLn = numeric(K + 1)
			for(k in seq_len(K)){
				members = X[Z == k, , drop = FALSE]
				lnLn[k] = log(nrow(members)) +
					ln_marginal_likelihood(x, posterior(members, theta0))
			}
			lnLn[K + 1] = log(alpha) + ln_marginal_likelihood(x, theta0)

			# subtract the max before exponentiating to avoid underflow
			Z[n] = sample.int(K + 1, 1, prob = exp(lnLn - max(lnLn)))
			if(Z[n] == K + 1){
				K = K + 1
			}
		}
	}

	list(Z = Z, K = K)
}

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA