Test de la solution de [Ho et al, Denoising Diffusion Probabilistic Models, NeurIPS 2020] sur nuages de points 2D¶

Laurent Risser (CNRS/IMT/ANITI) -- octobre 2024¶

Remarques :

  • On étudie un nuage de points 2D échantillonné suivant un mélange de deux lois normales.
  • Le paramètre $\beta_t$ est constant. $\forall t$, on a alors $\beta_t=\beta$.
  • La variance du bruit $\sigma^2$ est fixée comme étant égale à $\beta_t$.
In [1]:
import numpy as np
import matplotlib.pyplot as plt

1 : génération de données synthétiques...¶

In [2]:
def generate_data(n):
    """
    Draw n 2D points from a mixture of two normal distributions.

    About three quarters of the points come from a Gaussian centred at (-1,0)
    with a diagonal covariance; the remaining ones come from a correlated
    Gaussian centred at (1,0). Rows are shuffled before being returned.

    Input:
      - n: number of observations -- non-negative integer
    Output:
      - 2D numpy array -> size=[n,2]
    """
    n1=int(3.*n/4.)
    # The second component takes the remainder, so that exactly n points are
    # returned even when n is not a multiple of 4.
    n2=int(n)-n1

    X1=np.random.multivariate_normal(mean=np.array([-1,0]), cov=np.array([[0.1,0],[0,0.4]]), size=n1)
    X2=np.random.multivariate_normal(mean=np.array([1,0]), cov=np.array([[0.2,0.15],[0.15,0.2]]), size=n2)

    X=np.concatenate((X1, X2), axis=0)
    np.random.shuffle(X)  # in-place shuffle of the rows
    return X
In [3]:
# Draw a reference sample of 200 points and display it as a scatter plot
X0 = generate_data(n=200)

plt.scatter(*X0.T)  # X0.T unpacks into the x and y coordinate columns
plt.show()
No description has been provided for this image
In [ ]:
 

2 : fonctions pour la diffusion forward¶

2.1 function definitions¶

In [4]:
def sample_q_Xtm1_to_Xt(Xtm1,beta=0.01):
    """
    Apply one step of the forward (noising) process, from x_{t-1} to x_{t}.

    Input:
      - Xtm1: 2D array x_{t-1} -> size=[nb obs,2]
      - beta: noise level of the step -- scalar
    Output: 2D array x_{t} -> size=[nb obs,2]
    """
    nb_obs = Xtm1.shape[0]
    keep_ratio = 1 - beta  # alpha in the notations of [Ho et al 2020]

    # one standard 2D Gaussian draw per observation
    eps = np.random.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=nb_obs)

    shrunk_signal = np.sqrt(keep_ratio) * Xtm1
    added_noise = np.sqrt(1 - keep_ratio) * eps

    return shrunk_signal + added_noise


def sample_q_X0_to_Xt(X0,t,beta=0.01):
    """
    Jump directly from x_0 to x_t with the closed-form forward marginal.

    Input:
      - X0: 2D array x_{0} -> size=[nb obs,2]
      - t: number of forward steps -- scalar
      - beta: constant noise level of each step -- scalar
    Output: 2D array x_{t} -> size=[nb obs,2]
    """
    nb_obs = X0.shape[0]

    # cumulative product of the (constant) alpha over t steps
    cumulated_alpha = np.power(1 - beta, t)

    # one standard 2D Gaussian draw per observation
    eps = np.random.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=nb_obs)

    signal_scale = np.sqrt(cumulated_alpha)
    noise_scale = np.sqrt(1 - cumulated_alpha)

    return signal_scale * X0 + noise_scale * eps
    


def sample_q_X0_to_Xt_with_known_noise(X0,t,noise,beta=0.01):
    """
    Input: 
      - 2D array x_{t-1}
      - scalar t>1
    Output: 2D array x_{t}
    """
    
    n=X0.shape[0]
    alpha=1-beta
    alpha_bar=np.power(alpha,t)
    
    Xt=np.sqrt(alpha_bar)*X0+np.sqrt(1-alpha_bar)*noise  
    
    return Xt
    

2.2 test the incremental forward diffusion¶

In [5]:
# Diffusion times at which the point cloud is displayed
TimeSteps=[1,10,20,30,40]

# One forward step applied to X0 gives x_1, i.e. the state at time TimeSteps[0]
Xt=sample_q_Xtm1_to_Xt(X0)

plt.scatter(Xt[:,0],Xt[:,1])
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.title('T='+str(TimeSteps[0]))  # the cloud has already been diffused once, so this is not T=0
plt.show()

for t_start,t_end in zip(TimeSteps[:-1],TimeSteps[1:]):
    # apply (t_end - t_start) incremental steps to go from x_{t_start} to x_{t_end}
    for _ in range(t_start,t_end):
        Xt=sample_q_Xtm1_to_Xt(Xt)

    plt.scatter(Xt[:,0],Xt[:,1])
    plt.xlim(-3,3)
    plt.ylim(-3,3)
    plt.title('T='+str(t_end))
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

2.3 test the direct forward diffusion¶

In [6]:
# Diffuse X0 to time TimeStep in a single call and display the result
TimeStep = 20

Xt = sample_q_X0_to_Xt(X0, t=TimeStep)

plt.scatter(*Xt.T)  # Xt.T unpacks into the x and y coordinate columns
plt.xlim((-3, 3))
plt.ylim((-3, 3))
plt.title(f'T={TimeStep}')
plt.show()
No description has been provided for this image

3 : fonctions pour la diffusion backward¶

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
In [8]:
class MLP_4_2D_point_cloud(nn.Module):
    r"""
    Basic multi-layer perceptron that takes as input observations in 3 dimensions [X0,X1,t]
    and returns output observations in 2 dimensions [X0,X1].
    -> Used in this notebook as the noise predictor \epsilon_{\theta}(x_t,t)
       of the [Ho et al 2020] paper
    """
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3,256)
        self.linear2 = nn.Linear(256,256)
        self.linear3 = nn.Linear(256,2)

    def forward(self,X):
        """
        Input: tensor of size [nb obs,3]
        Output: tensor of size [nb obs,2]
        """
        X = F.relu(self.linear1(X))
        X = F.relu(self.linear2(X))
        X = self.linear3(X)  # no activation on the output layer
        return X

    def make_pred_with_np_data(self,np_Xt,t):
        """
        Inference helper working on numpy data (no gradient is tracked).

        Input:
          - numpy array x_{t} -> size=[nb obs,2]
          - scalar t, written as is in the third input variable of every observation
        Output:
          - 2D numpy array epsilon_theta(x_{t},t) -> size=[nb obs,2]
        """

        n=np_Xt.shape[0]

        pt_input=torch.zeros([n,3])  # n obs with 3 variables
        pt_input[:,0:2]=torch.tensor(np_Xt)
        pt_input[:,2]=t

        # Pure inference: skip the construction of the autograd graph
        with torch.no_grad():
            pt_output=self.forward(pt_input)

        return pt_output.detach().numpy()
        
        
        

# Instantiate a freshly initialised (not yet trained) network and print its layers
MLP_epsilon_theta = MLP_4_2D_point_cloud()
print(MLP_epsilon_theta)
MLP_4_2D_point_cloud(
  (linear1): Linear(in_features=3, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=2, bias=True)
)
In [9]:
def sample_p_Xt_to_Xtm1(Xt,t,Tmax,NN_4_epislon_theta_t,beta=0.01,sigma2=0.01):
    """
    Apply one step of the backward (denoising) process, from x_{t} to x_{t-1}.

    Input:
      - Xt: 2D array x_{t} -> size=[nb obs,2]
      - t: scalar t>=1 (with t=0 the term sqrt(1-alpha_bar) is zero and the mean is undefined)
      - Tmax: number of diffusion times, used to rescale t before calling the network
      - NN_4_epislon_theta_t: object whose make_pred_with_np_data(Xt,t/Tmax) returns
        the predicted noise as a 2D array with the same size as Xt
      - beta: constant noise level of each forward step -- scalar
      - sigma2: variance of the noise added around the predicted mean -- scalar
    Output: 2D array x_{t-1}
    """
    alpha=1-beta
    alpha_bar=np.power(alpha,t)
    n=Xt.shape[0]

    if t>1:
        noise=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=n)
    else:
        # last step: almost no noise is added to the predicted mean
        noise=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[0.001,0],[0,0.001]]), size=n) #std close to 0 here

    epsilon_theta_t=NN_4_epislon_theta_t.make_pred_with_np_data(Xt,t/Tmax) #WARNING: t is rescaled using Tmax

    mu_theta_t=(1./np.sqrt(alpha))*(Xt - ((1-alpha)/np.sqrt(1-alpha_bar))*epsilon_theta_t)

    # sigma2 is the variance actually applied (it used to be ignored in favour of beta);
    # with the default values sigma2=beta=0.01 the result is unchanged.
    Xtm1=mu_theta_t+np.sqrt(sigma2)*noise

    return Xtm1


def sample_p_XTmax_to_X0(Tmax,n,NN_4_epislon_theta_t,beta=0.01,sigma2=0.01):
    """
    Generate n points by running the whole backward chain, from t=Tmax down to t=0.

    Input:
      - Tmax: number of diffusion times -- positive integer
      - n: number of observations -- positive integer
      - NN_4_epislon_theta_t, beta, sigma2: forwarded to sample_p_Xt_to_Xtm1
    Output:
      - generated 2D array x_0
    """

    # start from pure standard Gaussian noise at t=Tmax
    cloud = np.random.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=n)

    # walk the diffusion times backward: Tmax, Tmax-1, ..., 1
    for t in reversed(range(1, Tmax + 1)):
        cloud = sample_p_Xt_to_Xtm1(
            cloud, t, Tmax, NN_4_epislon_theta_t, beta=beta, sigma2=sigma2
        )

    # hand back an independent copy of the final state
    return cloud.copy()
        
                                        
In [10]:
# Points generated by the backward process (blue) vs points drawn from the reference model (red)
X0 = sample_p_XTmax_to_X0(Tmax=40, n=200, NN_4_epislon_theta_t=MLP_epsilon_theta)

directly_generated_data = generate_data(n=200)

plt.scatter(*X0.T, c='b')
plt.scatter(*directly_generated_data.T, c='r')
plt.xlim((-3, 3))
plt.ylim((-3, 3))
plt.title('T=0')
plt.show()
No description has been provided for this image

On remarquera ici que le réseau de neurones NN_4_epislon_theta_t n'étant pas encore entrainé, les observations reconstruites (points bleus) ne suivent pas la distribution désirée (points rouges).

4 : Apprentissage d'un MLP_epsilon_theta lié à generate_data¶

In [11]:
#parameters to tune
Total_training_iterations=1500   
Tmax=50
beta=0.01
sigma2=0.01
BATCH_SIZE=128

#instantiate the epsilon_theta
MLP_epsilon_theta = MLP_4_2D_point_cloud()

#initiate the training process
optimizer = torch.optim.Adam(MLP_epsilon_theta.parameters(),lr=0.0001) #, betas=(0.9,0.999))
error = nn.MSELoss()
MLP_epsilon_theta.train()

#mini-batch buffers, overwritten at each iteration
pt_input_mb=torch.zeros([BATCH_SIZE,3])
pt_output_mb=torch.zeros([BATCH_SIZE,2])

errors_evo=[]

for iteration in range(Total_training_iterations):
    #draw X0 and t 
    X0=generate_data(BATCH_SIZE)
    #t is drawn uniformly in {1,...,Tmax}: these are the times queried by the backward
    #sampling loop (the upper bound of randint is exclusive). Drawing in {0,...,Tmax-1}
    #would never train t=Tmax and would waste iterations on t=0, where Xt=X0 carries
    #no information about the noise to predict.
    t=np.random.randint(1,Tmax+1)

    #draw a noise using N(0,I)
    np_epsilon=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=BATCH_SIZE)
  
    #data diffusion from time 0 to time t (with the beta tuned above, not the function default)
    Xt=sample_q_X0_to_Xt_with_known_noise(X0,t,np_epsilon,beta=beta)
    
    #prepare the input and output mini-batch
    pt_input_mb[:,0:2]=torch.tensor(Xt)
    pt_input_mb[:,2]=(1.*t)/Tmax   #t is rescaled using Tmax, as done at sampling time

    pt_output_mb[:,:]=torch.tensor(np_epsilon)   #the regression target is the noise itself
        
    #prediction
    pt_pred_mb=MLP_epsilon_theta(pt_input_mb)

    #compute and backpropagate the error
    optimizer.zero_grad()

    loss = error(pt_pred_mb, pt_output_mb)
    loss.backward()

    #gradient-descent step on the nn parameters
    optimizer.step()

    #show and store the error
    errors_evo.append(loss.item())
    #print('iteration:'+str(iteration)+' / drawn t='+str(t)+' / error='+str(errors_evo[-1]))

plt.plot(errors_evo)
plt.show()
No description has been provided for this image
In [12]:
# Generated points after training (blue) vs points drawn from the reference model (red)
X0 = sample_p_XTmax_to_X0(Tmax, n=BATCH_SIZE, NN_4_epislon_theta_t=MLP_epsilon_theta)

directly_generated_data = generate_data(n=BATCH_SIZE)

plt.scatter(*X0.T, c='b')
plt.scatter(*directly_generated_data.T, c='r')
plt.xlim((-3, 3))
plt.ylim((-3, 3))
plt.show()
No description has been provided for this image
In [13]:
# Same comparison with 100 times more points; a low opacity makes the densities visible
X0 = sample_p_XTmax_to_X0(Tmax, n=100*BATCH_SIZE, NN_4_epislon_theta_t=MLP_epsilon_theta)

directly_generated_data = generate_data(n=100*BATCH_SIZE)

plt.scatter(*X0.T, c='b', alpha=0.01)
plt.scatter(*directly_generated_data.T, c='r', alpha=0.01)
plt.xlim((-3, 3))
plt.ylim((-3, 3))
plt.show()
No description has been provided for this image

Le réseau de neurones NN_4_epislon_theta_t étant maintenant entrainé, les observations reconstruites (points bleus) suivent une distribution très proche de celle désirée (points rouges échantillonnés avec notre modèle de référence).

5 : détail d'exemples de processus forward et backward¶

In [14]:
def sample_p_XTmax_to_X0__with_animation(Tmax,n,NN_4_epislon_theta_t,beta=0.01,sigma2=0.01,prefix='',alpha=1):
    """
    Run the backward process from t=Tmax down to t=0, plotting and saving one frame per state.

    Frames are saved as prefix+'step<k>__t=<t>', where k is the number of backward
    steps already applied (0 for the initial noise) and t is the diffusion time of
    the displayed points. This gives Tmax+1 distinct files: the first backward step
    no longer overwrites the frame of the initial noise.

    Input:
      - Tmax: number of diffusion times -- positive integer
      - n: number of observations -- positive integer
      - NN_4_epislon_theta_t, beta, sigma2: forwarded to sample_p_Xt_to_Xtm1
      - prefix: string prepended to each saved file name (may contain a folder,
        which is created if missing)
      - alpha: opacity of the plotted points
    Output:
      - generated 2D array x_0
    """
    import os  # local import: only needed to prepare the output folder

    out_dir=os.path.dirname(prefix)
    if out_dir:
        os.makedirs(out_dir,exist_ok=True)

    def _plot_frame(points,step,t_shown):
        # One figure per frame, so that frames never accumulate points of earlier
        # states when plt.show() does not close the figure (non-notebook backends).
        fig = plt.figure()
        plt.scatter(points[:,0],points[:,1],c='b',alpha=alpha)
        plt.xlim(-3,3)
        plt.ylim(-3,3)
        plt.savefig(prefix+'step'+str(step)+'__t='+str(t_shown))
        plt.show()
        plt.close(fig)

    #sample initial data at t=Tmax
    Xt=np.random.multivariate_normal(mean=np.array([0,0]), cov=np.array([[1.,0],[0,1.]]), size=n)
    _plot_frame(Xt,0,Tmax)

    #inverse diffusion process
    for step,t in enumerate(range(Tmax,0,-1),start=1):
        Xt=sample_p_Xt_to_Xtm1(Xt,t,Tmax,NN_4_epislon_theta_t,beta=beta,sigma2=sigma2)
        _plot_frame(Xt,step,t-1)  # after the step, the points are at time t-1

    #get and return generated data
    X0=Xt.copy()

    return X0
In [15]:
# Backward animation with many points, drawn almost transparent, saved under images2/
X0 = sample_p_XTmax_to_X0__with_animation(
    Tmax,
    n=100*BATCH_SIZE,
    NN_4_epislon_theta_t=MLP_epsilon_theta,
    prefix='images2/',
    alpha=0.01,
)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [16]:
# Backward animation with a single mini-batch of points, saved under images1/
X0 = sample_p_XTmax_to_X0__with_animation(
    Tmax,
    n=BATCH_SIZE,
    NN_4_epislon_theta_t=MLP_epsilon_theta,
    prefix='images1/',
)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [17]:
def sample_X0_to_XTmax__with_animation(Tmax,n,beta=0.01,prefix='',alpha=1):
    """
    Run the forward process from t=0 up to t=Tmax, plotting and saving one frame per state.

    Frames are saved as prefix+'step<t>__t=<t>' for t=0,...,Tmax (Tmax+1 files).
    The loop applies Tmax forward steps, so the returned state is x_Tmax
    (it used to stop one step early, at x_{Tmax-1}).

    Input:
      - Tmax: number of diffusion times -- positive integer
      - n: number of observations -- positive integer
      - beta: constant noise level of each forward step -- scalar
      - prefix: string prepended to each saved file name (may contain a folder,
        which is created if missing)
      - alpha: opacity of the plotted points
    Output:
      - diffused 2D array x_Tmax
    """
    import os  # local import: only needed to prepare the output folder

    out_dir=os.path.dirname(prefix)
    if out_dir:
        os.makedirs(out_dir,exist_ok=True)

    def _plot_frame(points,t_shown):
        # One figure per frame, so that frames never accumulate points of earlier
        # states when plt.show() does not close the figure (non-notebook backends).
        fig = plt.figure()
        plt.scatter(points[:,0],points[:,1],c='b',alpha=alpha)
        plt.xlim(-3,3)
        plt.ylim(-3,3)
        plt.savefig(prefix+'step'+str(t_shown)+'__t='+str(t_shown))
        plt.show()
        plt.close(fig)

    #sample initial data at t=0
    X0=generate_data(n)

    Xt=X0.copy()
    _plot_frame(Xt,0)

    #diffusion process: Tmax steps, so that the last state is x_Tmax
    for t in range(1,Tmax+1):
        Xt=sample_q_Xtm1_to_Xt(Xt,beta=beta)
        _plot_frame(Xt,t)

    return Xt
In [19]:
# Forward animation on one mini-batch; the bare call lets the notebook echo the returned array
sample_X0_to_XTmax__with_animation(
    Tmax,
    n=BATCH_SIZE,
    prefix='./images_forward/',
    alpha=1,
)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Out[19]:
array([[-0.85147427, -0.09387519],
       [ 1.13971844, -0.20633607],
       [-1.48284176, -0.53558267],
       [-0.75296759,  0.62479981],
       [ 0.01789809,  1.28611603],
       [-0.05466133, -0.53812781],
       [-0.33232868, -0.28603554],
       [-1.82373845,  0.67409527],
       [ 0.36758342,  0.10874896],
       [-0.77471793,  0.58196685],
       [-0.60190492, -0.54956411],
       [-1.85727634,  0.17032773],
       [-0.0578325 , -0.9878954 ],
       [-0.17450285,  0.40430345],
       [ 0.40344898, -0.00837076],
       [-1.091252  , -0.10442134],
       [-1.56317719,  0.48148117],
       [-0.18500177, -0.2540611 ],
       [ 0.06208434, -0.39449851],
       [-0.99569039,  0.76363145],
       [ 0.57947934,  0.22828075],
       [ 0.05562935, -0.62911406],
       [ 1.26927149, -0.21772577],
       [ 0.34170507, -0.76087633],
       [ 0.95016106, -0.3881387 ],
       [ 0.12210803,  0.44303877],
       [-1.31427313, -0.97777285],
       [ 1.2991974 ,  0.12020855],
       [-2.23657387,  0.59147952],
       [-1.11359017, -0.90424038],
       [-0.41512681, -0.29839213],
       [-0.92078677,  0.56763579],
       [-2.12989868, -0.71615911],
       [-0.82159785,  0.01660516],
       [-0.69276555, -0.19018295],
       [-1.69249854, -0.13051598],
       [-0.68058997,  1.07144639],
       [-0.11424221,  1.33439149],
       [ 1.34174495, -0.98213891],
       [-0.20110114,  0.38230831],
       [-0.68710954, -0.55204149],
       [-1.65030118,  0.61701966],
       [-1.98848894,  0.08355589],
       [-0.20459747, -0.92601755],
       [ 0.39758717, -0.26475847],
       [ 1.06480286,  0.10526594],
       [-0.314905  , -1.31597948],
       [-1.96340309,  0.90430111],
       [ 1.1536414 ,  0.51136359],
       [-0.06497288,  0.29376254],
       [-0.14139664,  0.60756543],
       [-1.09584318, -1.67800273],
       [ 1.47033726,  0.21191616],
       [-0.38430325,  0.35948441],
       [ 0.25566555,  0.27943414],
       [-0.55924221,  1.12519549],
       [-1.03595881, -0.05218854],
       [-1.21454181,  1.23189122],
       [-1.12570938,  0.11599385],
       [-0.88494642,  1.29331838],
       [-1.23832515,  0.99618174],
       [-0.60413421,  0.12242985],
       [-1.07352414, -0.65251214],
       [-0.67993477,  0.1612385 ],
       [ 1.23796449,  1.77561084],
       [-0.98290174,  1.58599862],
       [-0.92569911, -0.24563209],
       [ 0.090616  ,  1.38028518],
       [ 0.67041988, -1.39277994],
       [-0.33913945, -0.29143047],
       [ 0.79914818, -0.45448646],
       [-1.50850347, -0.04469276],
       [ 2.11997226,  0.04383052],
       [-1.31000736,  0.02873652],
       [ 1.93763541,  1.08248287],
       [-1.23462289,  0.13272798],
       [ 0.68371433, -0.8039759 ],
       [-2.00176908, -0.8448913 ],
       [-2.60174759,  0.84957566],
       [-0.71315003,  0.42414229],
       [ 0.61097268, -1.81665991],
       [-0.4348249 , -0.51689862],
       [ 2.10055332, -0.3792362 ],
       [ 1.51494502,  0.30590504],
       [-1.60638061,  0.92873719],
       [ 0.36721403, -0.03756551],
       [-1.43713972,  0.11802398],
       [ 0.70293249, -0.36884881],
       [ 0.73003545,  0.04979899],
       [-0.81173352, -0.4858336 ],
       [-1.24168634, -0.63740752],
       [-1.05407629, -0.7669768 ],
       [-1.11174835, -0.0274489 ],
       [-0.87042204, -1.24685731],
       [ 1.16833645,  1.01539994],
       [-0.44621083,  0.07620836],
       [-0.16385715,  0.02236124],
       [-0.69762118,  0.61240797],
       [ 1.33597591,  0.51124805],
       [ 0.56994027,  0.11696339],
       [-0.3763391 , -0.33575284],
       [-0.06247219, -0.09456224],
       [-0.3953256 , -0.64967413],
       [-1.12693242, -1.18416111],
       [ 0.64891603, -0.18282929],
       [-0.86939397,  1.65015912],
       [ 0.85147841,  0.04257491],
       [-0.05333343, -1.86537853],
       [ 0.62115367,  0.36175966],
       [ 0.69693241,  0.83027949],
       [-1.48822298, -0.00811728],
       [-0.55115612, -0.34937586],
       [-0.72188518,  0.18048364],
       [ 1.12457473,  0.35885361],
       [-1.95007863, -0.59047873],
       [-2.32918418,  0.29724958],
       [-0.5043029 ,  0.2077733 ],
       [-0.50883539,  0.44714662],
       [ 0.18217484,  1.23813854],
       [ 0.22871856,  0.2341497 ],
       [ 0.83111286,  0.96868426],
       [-0.320984  , -0.95442558],
       [-0.50165268, -1.04721232],
       [-0.33495598, -0.26757387],
       [-1.37500561,  2.10443044],
       [-1.5215937 , -0.25409491],
       [-0.67141206, -1.10821087],
       [-0.55692806,  0.8830584 ]])
In [ ]: