Testing the method of [Ho et al., Denoising Diffusion Probabilistic Models, NeurIPS 2020] on 2D point clouds¶
Laurent Risser (CNRS/IMT/ANITI) -- October 2024¶
Remarks:
- We study a 2D point cloud sampled from a mixture of two normal distributions.
- The parameter $\beta_t$ is constant: $\beta_t=\beta$ for all $t$.
- The noise variance $\sigma^2$ is fixed equal to $\beta_t$.
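With these choices, the forward process of [Ho et al. 2020] has a simple closed form. Writing $\alpha=1-\beta$ and $\bar{\alpha}_t=\alpha^t$ (valid because $\beta$ is constant), the two sampling routines implemented below correspond to
$$q(x_t \mid x_{t-1})=\mathcal{N}\!\left(\sqrt{1-\beta}\,x_{t-1},\ \beta I\right), \qquad q(x_t \mid x_0)=\mathcal{N}\!\left(\sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\, I\right).$$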
In [1]:
import numpy as np
import matplotlib.pyplot as plt
1 : synthetic data generation¶
In [2]:
def generate_data(n):
    n1=int(3.*n/4.)
    n2=int(n/4.)
    X1=np.random.multivariate_normal(mean=np.array([-1,0]), cov=np.array([[0.1,0],[0,0.4]]), size=n1)
    X2=np.random.multivariate_normal(mean=np.array([1,0]), cov=np.array([[0.2,0.15],[0.15,0.2]]), size=n2)
    X=np.concatenate((X1, X2), axis=0)
    np.random.shuffle(X)
    return X
In [3]:
X0=generate_data(200)
plt.scatter(X0[:,0],X0[:,1])
plt.show()
2 : functions for the forward diffusion¶
2.1 Function definitions¶
In [4]:
def sample_q_Xtm1_to_Xt(Xtm1,beta=0.01):
    """
    Input: 2D array x_{t-1}
    Output: 2D array x_{t}
    """
    n=Xtm1.shape[0]
    alpha=1-beta
    noise=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=n)
    Xt=np.sqrt(alpha)*Xtm1+np.sqrt(1-alpha)*noise
    return Xt
def sample_q_X0_to_Xt(X0,t,beta=0.01):
    """
    Input:
    - 2D array x_{0}
    - scalar t>=1
    Output: 2D array x_{t}
    """
    n=X0.shape[0]
    alpha=1-beta
    alpha_bar=np.power(alpha,t) #constant beta, so alpha_bar_t = alpha^t
    noise=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=n)
    Xt=np.sqrt(alpha_bar)*X0+np.sqrt(1-alpha_bar)*noise
    return Xt
def sample_q_X0_to_Xt_with_known_noise(X0,t,noise,beta=0.01):
    """
    Input:
    - 2D array x_{0}
    - scalar t>=1
    - 2D array noise drawn from N(0,I)
    Output: 2D array x_{t}
    """
    alpha=1-beta
    alpha_bar=np.power(alpha,t)
    Xt=np.sqrt(alpha_bar)*X0+np.sqrt(1-alpha_bar)*noise
    return Xt
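As an optional sanity check (not in the original notebook), one can verify empirically that iterating sample_q_Xtm1_to_Xt $t$ times and calling sample_q_X0_to_Xt once produce samples with matching first and second moments; a minimal sketch:
In [ ]:
#optional sanity check: incremental vs direct forward diffusion
t_check=20
X_incr=generate_data(10000)
for _ in range(t_check):
    X_incr=sample_q_Xtm1_to_Xt(X_incr)
X_direct=sample_q_X0_to_Xt(generate_data(10000),t_check)
print('incremental: mean=',X_incr.mean(axis=0),' var=',X_incr.var(axis=0))
print('direct     : mean=',X_direct.mean(axis=0),' var=',X_direct.var(axis=0))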
2.2 Test the incremental forward diffusion¶
In [5]:
Xt=sample_q_Xtm1_to_Xt(X0)
plt.scatter(Xt[:,0],Xt[:,1])
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.title('T='+str(1)) #one diffusion step has been applied
plt.show()
TimeSteps=[1,10,20,30,40]
for i in range(0,len(TimeSteps)-1):
    for j in range(TimeSteps[i],TimeSteps[i+1]):
        Xt=sample_q_Xtm1_to_Xt(Xt)
    plt.scatter(Xt[:,0],Xt[:,1])
    plt.xlim(-3,3)
    plt.ylim(-3,3)
    plt.title('T='+str(TimeSteps[i+1]))
    plt.show()
2.3 Test the direct forward diffusion¶
In [6]:
TimeStep=20
Xt=sample_q_X0_to_Xt(X0,TimeStep)
plt.scatter(Xt[:,0],Xt[:,1])
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.title('T='+str(TimeStep))
plt.show()
3 : functions for the backward diffusion¶
In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
In [8]:
class MLP_4_2D_point_cloud(nn.Module):
    """
    Basic multi-layer perceptron that takes as input observations in 3 dimensions [X0,X1,t]
    and returns output observations in 2 dimensions [X0,X1].
    -> Corresponds to \epsilon_{\theta} in the [Ho et al 2020] paper
    """
    def __init__(self):
        super(MLP_4_2D_point_cloud, self).__init__()
        self.linear1 = nn.Linear(3,256)
        self.linear2 = nn.Linear(256,256)
        self.linear3 = nn.Linear(256,2)

    def forward(self,X):
        X = F.relu(self.linear1(X))
        X = F.relu(self.linear2(X))
        X = self.linear3(X)
        return X

    def make_pred_with_np_data(self,np_Xt,t):
        """
        Input:
        - numpy array x_{t} -> size=[nb obs,2]
        - scalar t (diffusion time rescaled to [0,1])
        Output:
        - 2D numpy array epsilon_theta(x_{t},t) -> size=[nb obs,2]
        """
        n=np_Xt.shape[0]
        pt_input=torch.zeros([n,3]) # n obs with 3 variables
        pt_input[:,0:2]=torch.tensor(np_Xt)
        pt_input[:,2]=t
        pt_output=self.forward(pt_input)
        return pt_output.detach().numpy()

MLP_epsilon_theta = MLP_4_2D_point_cloud()
print(MLP_epsilon_theta)
MLP_4_2D_point_cloud(
  (linear1): Linear(in_features=3, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=2, bias=True)
)
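As a quick shape check (illustrative only, with an arbitrary rescaled time of 0.5), the helper make_pred_with_np_data maps an (n,2) point cloud to an (n,2) array of predicted noise vectors:
In [ ]:
#illustrative call on the synthetic data (network still untrained)
pred=MLP_epsilon_theta.make_pred_with_np_data(X0,0.5)
print(pred.shape) #-> (200, 2): one predicted noise vector per input point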
In [9]:
def sample_p_Xt_to_Xtm1(Xt,t,Tmax,NN_4_epislon_theta_t,beta=0.01,sigma2=0.01):
    """
    Input:
    - 2D array x_{t}
    - scalar t>=1
    Output: 2D array x_{t-1}
    """
    alpha=1-beta
    alpha_bar=np.power(alpha,t)
    n=Xt.shape[0]
    if t>1:
        noise=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=n)
    else:
        noise=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[0.001,0],[0,0.001]]), size=n) #std close to 0 here
    epsilon_theta_t=NN_4_epislon_theta_t.make_pred_with_np_data(Xt,t/Tmax) #WARNING: t is rescaled using Tmax
    mu_theta_t=(1./np.sqrt(alpha))*(Xt - ((1-alpha)/np.sqrt(1-alpha_bar))*epsilon_theta_t)
    Xtm1=mu_theta_t+np.sqrt(sigma2)*noise #noise std is sqrt(sigma2), with sigma2=beta by default
    return Xtm1

def sample_p_XTmax_to_X0(Tmax,n,NN_4_epislon_theta_t,beta=0.01,sigma2=0.01):
    """
    Input:
    - Tmax: number of diffusion times -- positive integer
    - n: number of observations -- positive integer
    Output:
    - generated 2D array x_0
    """
    #sample initial data at t=Tmax
    Xtmax=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=n)
    #inverse diffusion process
    Xt=Xtmax.copy()
    for t in range(Tmax,0,-1):
        Xt=sample_p_Xt_to_Xtm1(Xt,t,Tmax,NN_4_epislon_theta_t,beta=beta,sigma2=sigma2)
    #get and return generated data
    X0=Xt.copy()
    return X0
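These two functions implement the sampling step of Algorithm 2 in [Ho et al. 2020]: given $x_t$, the network prediction $\epsilon_\theta(x_t,t)$ defines the mean of the reverse transition, and $x_{t-1}$ is drawn as
$$x_{t-1}=\frac{1}{\sqrt{\alpha}}\left(x_t-\frac{1-\alpha}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t,t)\right)+\sigma z, \qquad z\sim\mathcal{N}(0,I),$$
with $\sigma^2=\beta$ here and $z$ replaced by a nearly zero perturbation at the final step $t=1$.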
In [10]:
X0=sample_p_XTmax_to_X0(40,200,MLP_epsilon_theta)
directly_generated_data=generate_data(200)
plt.scatter(X0[:,0],X0[:,1],c='b')
plt.scatter(directly_generated_data[:,0],directly_generated_data[:,1],c='r')
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.title('T=0')
plt.show()
Note that since the neural network NN_4_epislon_theta_t has not been trained yet, the reconstructed observations (blue points) do not follow the target distribution (red points).
4 : Training MLP_epsilon_theta on data from generate_data¶
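The training loop below follows Algorithm 1 of [Ho et al. 2020] with the simplified objective: draw $x_0$ from the data distribution, a time $t$, and a noise $\epsilon\sim\mathcal{N}(0,I)$, then take a gradient step on
$$L(\theta)=\left\|\epsilon-\epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\|^2,$$
where $t$ is rescaled to $[0,1]$ before being fed to the network.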
In [11]:
#parameters to tune
Total_training_iterations=1500
Tmax=50
beta=0.01
sigma2=0.01
BATCH_SIZE=128

#instantiate the epsilon_theta network
MLP_epsilon_theta = MLP_4_2D_point_cloud()

#initiate the training process
optimizer = torch.optim.Adam(MLP_epsilon_theta.parameters(),lr=0.0001) #, betas=(0.9,0.999))
error = nn.MSELoss()
MLP_epsilon_theta.train()
pt_input_mb=torch.zeros([BATCH_SIZE,3])
pt_output_mb=torch.zeros([BATCH_SIZE,2])
errors_evo=[]

for iteration in range(Total_training_iterations):
    #draw X0 and t (a single t in {1,...,Tmax} is shared by the whole mini-batch)
    X0=generate_data(BATCH_SIZE)
    t=np.random.randint(1,Tmax+1)
    #draw a noise using N(0,I)
    np_epsilon=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=BATCH_SIZE)
    #data diffusion from time 0 to time t
    Xt=sample_q_X0_to_Xt_with_known_noise(X0,t,np_epsilon)
    #prepare the input and output mini-batch
    pt_input_mb[:,0:2]=torch.tensor(Xt)
    pt_input_mb[:,2]=(1.*t)/Tmax
    pt_output_mb[:,:]=torch.tensor(np_epsilon)
    #prediction
    pt_pred_mb=MLP_epsilon_theta(pt_input_mb)
    #compute and backpropagate the error
    optimizer.zero_grad()
    loss = error(pt_pred_mb, pt_output_mb)
    loss.backward()
    #gradient-descent step on the nn parameters
    optimizer.step()
    #show and store the error
    errors_evo.append(loss.item())
    #print('iteration:'+str(iteration)+' / drawn t='+str(t)+' / error='+str(errors_evo[-1]))

plt.plot(errors_evo)
plt.show()
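The raw curve is noisy because each value is the loss of a single mini-batch; an optional smoothing step (not in the original notebook, window size arbitrary) makes the trend easier to read:
In [ ]:
#optional: moving average of the stored mini-batch losses
window=50
smoothed=np.convolve(errors_evo,np.ones(window)/window,mode='valid')
plt.plot(smoothed)
plt.show()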
In [12]:
X0=sample_p_XTmax_to_X0(Tmax,BATCH_SIZE,MLP_epsilon_theta)
directly_generated_data=generate_data(BATCH_SIZE)
plt.scatter(X0[:,0],X0[:,1],c='b')
plt.scatter(directly_generated_data[:,0],directly_generated_data[:,1],c='r')
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.show()
In [13]:
X0=sample_p_XTmax_to_X0(Tmax,100*BATCH_SIZE,MLP_epsilon_theta)
directly_generated_data=generate_data(100*BATCH_SIZE)
plt.scatter(X0[:,0],X0[:,1],c='b',alpha=0.01)
plt.scatter(directly_generated_data[:,0],directly_generated_data[:,1],c='r',alpha=0.01)
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.show()
Now that the neural network NN_4_epislon_theta_t has been trained, the reconstructed observations (blue points) follow a distribution very close to the target one (red points sampled with our reference model).
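To back this visual comparison with numbers, one can compare the empirical means and covariances of the two point clouds; a minimal sketch (not in the original notebook):
In [ ]:
#compare first and second moments of generated vs reference samples
print('generated mean:',X0.mean(axis=0))
print('reference mean:',directly_generated_data.mean(axis=0))
print('generated cov:\n',np.cov(X0.T))
print('reference cov:\n',np.cov(directly_generated_data.T))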
5 : step-by-step examples of the forward and backward processes¶
In [14]:
def sample_p_XTmax_to_X0__with_animation(Tmax,n,NN_4_epislon_theta_t,beta=0.01,sigma2=0.01,prefix='',alpha=1):
    """
    Input:
    - Tmax: number of diffusion times -- positive integer
    - n: number of observations -- positive integer
    Output:
    - generated 2D array x_0
    """
    fig = plt.figure()
    #sample initial data at t=Tmax
    Xtmax=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=n)
    #inverse diffusion process
    Xt=Xtmax.copy()
    plt.scatter(Xt[:,0],Xt[:,1],c='b',alpha=alpha)
    plt.xlim(-3,3)
    plt.ylim(-3,3)
    plt.savefig(prefix+'step'+str(0)+'__t='+str(Tmax))
    plt.show()
    for t in range(Tmax,0,-1):
        Xt=sample_p_Xt_to_Xtm1(Xt,t,Tmax,NN_4_epislon_theta_t,beta=beta,sigma2=sigma2)
        plt.scatter(Xt[:,0],Xt[:,1],c='b',alpha=alpha)
        plt.xlim(-3,3)
        plt.ylim(-3,3)
        plt.savefig(prefix+'step'+str(Tmax-t+1)+'__t='+str(t-1)) #after this step the cloud is x_{t-1}
        plt.show()
    #get and return generated data
    X0=Xt.copy()
    return X0
In [15]:
X0=sample_p_XTmax_to_X0__with_animation(Tmax,100*BATCH_SIZE,MLP_epsilon_theta,prefix='images2/',alpha=0.01)
In [16]:
X0=sample_p_XTmax_to_X0__with_animation(Tmax,BATCH_SIZE,MLP_epsilon_theta,prefix='images1/')
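The saved frames can be stitched into an animation. A minimal sketch, assuming the imageio package (v2 API) is available, that the images1/ directory existed before the call above, and using the fact that plt.savefig writes PNG files by default:
In [ ]:
#assemble the backward-diffusion frames saved in images1/ into a GIF
import imageio.v2 as imageio
frames=[imageio.imread('images1/step'+str(i)+'__t='+str(Tmax-i)+'.png') for i in range(Tmax+1)]
imageio.mimwrite('backward_diffusion.gif',frames,duration=0.1)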
In [17]:
def sample_X0_to_XTmax__with_animation(Tmax,n,beta=0.01,prefix='',alpha=1):
    """
    Input:
    - Tmax: number of diffusion times -- positive integer
    - n: number of observations -- positive integer
    Output:
    - diffused 2D array x_{Tmax}
    """
    fig = plt.figure()
    #sample initial data at t=0
    X0=generate_data(n)
    Xt=X0.copy()
    #diffusion process
    plt.scatter(Xt[:,0],Xt[:,1],c='b',alpha=alpha)
    plt.xlim(-3,3)
    plt.ylim(-3,3)
    plt.savefig(prefix+'step'+str(0)+'__t='+str(0))
    plt.show()
    for t in range(1,Tmax+1):
        Xt=sample_q_Xtm1_to_Xt(Xt,beta=beta)
        plt.scatter(Xt[:,0],Xt[:,1],c='b',alpha=alpha)
        plt.xlim(-3,3)
        plt.ylim(-3,3)
        plt.savefig(prefix+'step'+str(t)+'__t='+str(t))
        plt.show()
    return Xt
In [19]:
XTmax=sample_X0_to_XTmax__with_animation(Tmax,BATCH_SIZE,prefix='./images_forward/',alpha=1) #assign the result so the notebook does not echo the full array