class MNMASH:
    """Multivariate sum-of-single-effects regression driven by a MASH fitter.

    The model keeps ``L`` single-effect components. Each sweep of :meth:`fit`
    visits every component in turn and does three things:

    1. Removes the component's current fitted contribution from the running
       total ``Xr0``.
    2. Refits that single component with MASH on the residual ``Y - Xr0``.
    3. Adds the new contribution back.

    This is the residual-update scheme of SuSiE's iterative Bayesian stepwise
    selection.

    Shapes (N rows of ``Y``, R columns of ``Y``, J columns of ``mash.X``,
    L components):

    - ``alpha0``: (L, J)
    - ``mu0``: (L, J, R)
    - ``Xr0``: (N, R)
    - ``post_mean_mat``: (J, R)
    """

    def __init__(self, X=None, Y=None, Z=None, B=None, S=None, V=None):
        """Build the underlying MASH fitter and zero-initialise the state.

        ``X, Y, Z, B, S`` are forwarded unchanged to ``MASH``. ``V`` is also
        forwarded. When it is omitted, it defaults to the sample covariance
        of the columns of ``Y`` (``np.cov(Y, rowvar=False)``).

        ``Y`` is required in practice, because its shape sizes the
        fitted-value matrix.
        """
        self.mash = MASH(X=X, Y=Y, Z=Z, B=B, S=S,
                         V=np.cov(Y, rowvar=False) if V is None else V)
        self.Y = Y
        # variational parameter for the one-nonzero effect model for each l and j
        self.alpha0 = None
        # posterior mean on \beta_lj
        self.mu0 = None
        # running sum of all L fitted components; re-zeroed by fit()
        self.Xr0 = np.zeros((self.Y.shape[0], self.Y.shape[1]))
        # ELBO history; _calc_elbo is still a stub, so this stays empty
        self.elbo = []
        # overall posterior mean, (J, R); set by fit()
        self.post_mean_mat = None

    def set_prior(self, U, grid, pi):
        """Forward the prior specification ``(U, grid, pi)`` to the MASH fitter."""
        self.mash.set_prior(U, grid, pi)

    def fit(self, niter=100, L=5, calc_elbo=False):
        """Fit ``L`` single-effect components for ``niter`` sweeps.

        Args:
            niter: number of full passes over the L components.
            L: number of single-effect components.
            calc_elbo: if true, call ``_calc_elbo`` after every sweep
                (currently a no-op placeholder).

        All component state is reset first. A repeated call therefore starts
        from scratch rather than from the previous fit's residuals. On return,
        ``post_mean_mat`` holds the (J, R) posterior mean.
        """
        n_var = self.mash.X.shape[1]
        n_resp = self.Y.shape[1]
        self.alpha0 = np.zeros((L, n_var))
        self.mu0 = np.zeros((L, n_var, n_resp))
        # Xr0 must agree with alpha0/mu0, which are both zero here. Without
        # this reset, a second fit() would start from stale fitted values.
        self.Xr0 = np.zeros((self.Y.shape[0], n_resp))
        self.elbo = []
        for _ in range(niter):
            self._calc_update()
            if calc_elbo:
                self._calc_elbo()
        self._calc_posterior()

    def _effect_fitted(self, l):
        """Return component ``l``'s fitted values, ``X @ (alpha0[l] * mu0[l])``, shape (N, R)."""
        return self.mash.X @ (self.alpha0[l, :, None] * self.mu0[l])

    def _calc_update(self):
        """Run one sweep, refitting each component on the residual left by the others."""
        for l in range(self.alpha0.shape[0]):
            # Take component l out of the running fit,
            self.Xr0 -= self._effect_fitted(l)
            # refit it against what the other components leave unexplained,
            self.alpha0[l], self.mu0[l] = self._calc_single_snp(self.Y - self.Xr0)
            # and put the updated component back.
            self.Xr0 += self._effect_fitted(l)

    def _calc_single_snp(self, R):
        """Fit one single-effect component to the residual matrix ``R`` with MASH.

        Returns:
            ``(alpha, mu)``:

            - ``alpha`` is a length-J weight vector proportional to each
              column's Bayes factor (uniform prior over columns) and sums to 1.
            - ``mu`` is MASH's ``post_mean_mat`` transposed, expected to be
              (J, R).
        """
        self.mash.reset({'Y': R})
        self.mash.get_summary_stats()
        self.mash.fit()
        # NOTE(review): l10bf is treated as a log10 Bayes factor, as its name
        # says, so it is exponentiated with base 10, not e -- confirm against
        # MASH. Normalising in log space (subtracting the max) keeps large BFs
        # from overflowing to inf and turning the weights into NaN.
        log_bf = np.asarray(self.mash.l10bf, dtype=float) * np.log(10)
        bf = np.exp(log_bf - np.max(log_bf))
        return bf / np.sum(bf), self.mash.post_mean_mat.T

    def _calc_elbo(self):
        """Placeholder for the ELBO computation; currently does nothing."""

    def _calc_posterior(self):
        """Set ``post_mean_mat`` to ``sum_l alpha0[l, j] * mu0[l, j, :]``, shape (J, R)."""
        self.post_mean_mat = np.sum(self.alpha0[:, :, None] * self.mu0, axis=0)