
1140225 meeting

Exploring the $SSSD^{S4}$ model

Original authors' GitHub: https://github.com/AI4HealthUOL/SSSD

The configuration file used is \SSSD-main\src\config\config_SSSDS4.json:

{   
    "diffusion_config":{
        "T": 200,
        "beta_0": 0.0001,
        "beta_T": 0.02
    },
    "wavenet_config": {
        "in_channels": 14, 
        "out_channels":14,
        "num_res_layers": 36,
        "res_channels": 256, 
        "skip_channels": 256,
        "diffusion_step_embed_dim_in": 128,
        "diffusion_step_embed_dim_mid": 512,
        "diffusion_step_embed_dim_out": 512,
        "s4_lmax": 100,
        "s4_d_state":64,
        "s4_dropout":0.0,
        "s4_bidirectional":1,
        "s4_layernorm":1
    },
    "train_config": {
        "output_directory": "/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/",
        "ckpt_iter": "max",
        "iters_per_ckpt": 100,
        "iters_per_logging": 100,
        "n_iters": 1000,
        "learning_rate": 2e-4,
        "only_generate_missing": 1,
        "use_model": 2,
        "masking": "rm",
        "missing_k": 90
    },
    "trainset_config":{
        "train_data_path": "/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/datasets/Mujoco/train_mujoco.npy",
        "test_data_path": "/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/datasets/Mujoco/test_mujoco.npy",
        "segment_length":100,
        "sampling_rate": 100
    },
    "gen_config":{
        "output_directory": "/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/",
        "ckpt_path": "/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/"
    }
}
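The three numbers in diffusion_config fully determine the noise schedule: T diffusion steps with per-step noise levels interpolated between beta_0 and beta_T. As a rough illustration, the sketch below derives the standard DDPM quantities that calc_diffusion_hyperparams in utils.util is expected to produce from these values, assuming a linear beta schedule; treat it as a minimal sketch rather than the repository's exact implementation.

import torch

# Assumed linear beta schedule between beta_0 and beta_T over T steps (DDPM-style).
T, beta_0, beta_T = 200, 0.0001, 0.02
Beta = torch.linspace(beta_0, beta_T, T)         # per-step noise variance beta_t
Alpha = 1.0 - Beta                               # alpha_t = 1 - beta_t
Alpha_bar = torch.cumprod(Alpha, dim=0)          # cumulative product of alpha_t
Alpha_bar_prev = torch.cat([torch.ones(1), Alpha_bar[:-1]])
Sigma = torch.sqrt(Beta * (1.0 - Alpha_bar_prev) / (1.0 - Alpha_bar))  # one common choice of reverse-step std

print(Beta[:3], float(Alpha_bar[-1]))            # beta starts tiny; Alpha_bar decays toward 0 as t -> T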

Run \SSSD-main\src\train.py to train the model:

import os
import argparse
import json
import numpy as np
import torch
import torch.nn as nn

from utils.util import find_max_epoch, print_size, training_loss, calc_diffusion_hyperparams
from utils.util import get_mask_mnr, get_mask_bm, get_mask_rm

from imputers.DiffWaveImputer import DiffWaveImputer
from imputers.SSSDSAImputer import SSSDSAImputer
from imputers.SSSDS4Imputer import SSSDS4Imputer


def train(output_directory,
          ckpt_iter,
          n_iters,
          iters_per_ckpt,
          iters_per_logging,
          learning_rate,
          use_model,
          only_generate_missing,
          masking,
          missing_k):
  
    """
    Train Diffusion Models

    Parameters:
    output_directory (str):         save model checkpoints to this path
    ckpt_iter (int or 'max'):       the pretrained checkpoint to be loaded; 
                                    automatically selects the maximum iteration if 'max' is selected
    data_path (str):                path to dataset, numpy array.
    n_iters (int):                  number of iterations to train
    iters_per_ckpt (int):           number of iterations to save checkpoint, 
                                    default is 10k, for models with residual_channel=64 this number can be larger
    iters_per_logging (int):        number of iterations to save training log and compute validation loss, default is 100
    learning_rate (float):          learning rate

    use_model (int):                0:DiffWave. 1:SSSDSA. 2:SSSDS4.
    only_generate_missing (int):    0:all sample diffusion.  1:only apply diffusion to missing portions of the signal
    masking(str):                   'mnr': missing not at random, 'bm': blackout missing, 'rm': random missing
    missing_k (int):                k missing time steps for each feature across the sample length.
    """

    # generate experiment (local) path
    local_path = "T{}_beta0{}_betaT{}".format(diffusion_config["T"],
                                              diffusion_config["beta_0"],
                                              diffusion_config["beta_T"])

    # Get shared output_directory ready
    output_directory = os.path.join(output_directory, local_path)
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory, flush=True)

    # map diffusion hyperparameters to gpu
    for key in diffusion_hyperparams:
        if key != "T":
            diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()

    # predefine model
    if use_model == 0:
        net = DiffWaveImputer(**model_config).cuda()
    elif use_model == 1:
        net = SSSDSAImputer(**model_config).cuda()
    elif use_model == 2:
        net = SSSDS4Imputer(**model_config).cuda()
    else:
        print('Model chosen not available.')
    print_size(net)

    # define optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # load checkpoint
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(output_directory)
    if ckpt_iter >= 0:
        try:
            # load checkpoint file
            model_path = os.path.join(output_directory, '{}.pkl'.format(ckpt_iter))
            checkpoint = torch.load(model_path, map_location='cpu')

            # feed model dict and optimizer state
            net.load_state_dict(checkpoint['model_state_dict'])
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            print('Successfully loaded model at iteration {}'.format(ckpt_iter))
        except:
            ckpt_iter = -1
            print('No valid checkpoint model found, start training from initialization.')
    else:
        ckpt_iter = -1
        print('No valid checkpoint model found, start training from initialization.')

  
  
  
    ### Custom data loading and reshaping ###
  
  

    training_data = np.load(trainset_config['train_data_path'])
    training_data = np.split(training_data, 160, 0)
    training_data = np.array(training_data)
    training_data = torch.from_numpy(training_data).float().cuda()
    print('Data loaded')

  
  
    # training
    n_iter = ckpt_iter + 1
    while n_iter < n_iters + 1:
        for batch in training_data:

            if masking == 'rm':
                transposed_mask = get_mask_rm(batch[0], missing_k)
            elif masking == 'mnr':
                transposed_mask = get_mask_mnr(batch[0], missing_k)
            elif masking == 'bm':
                transposed_mask = get_mask_bm(batch[0], missing_k)

            mask = transposed_mask.permute(1, 0)
            mask = mask.repeat(batch.size()[0], 1, 1).float().cuda()
            loss_mask = ~mask.bool()
            batch = batch.permute(0, 2, 1)

            assert batch.size() == mask.size() == loss_mask.size()

            # back-propagation
            optimizer.zero_grad()
            X = batch, batch, mask, loss_mask
            loss = training_loss(net, nn.MSELoss(), X, diffusion_hyperparams,
                                 only_generate_missing=only_generate_missing)

            loss.backward()
            optimizer.step()

            if n_iter % iters_per_logging == 0:
                print("iteration: {} \tloss: {}".format(n_iter, loss.item()))

            # save checkpoint
            if n_iter > 0 and n_iter % iters_per_ckpt == 0:
                checkpoint_name = '{}.pkl'.format(n_iter)
                torch.save({'model_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()},
                           os.path.join(output_directory, checkpoint_name))
                print('model at iteration %s is saved' % n_iter)

            n_iter += 1

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/config/config_SSSDS4.json',
                    help='JSON file for configuration')

args = parser.parse_args()

with open(args.config) as f:
    data = f.read()

config = json.loads(data)
print(config)

train_config = config["train_config"]  # training parameters

global trainset_config
trainset_config = config["trainset_config"]  # to load trainset

global diffusion_config
diffusion_config = config["diffusion_config"]  # basic hyperparameters

global diffusion_hyperparams
diffusion_hyperparams = calc_diffusion_hyperparams(
    **diffusion_config)  # dictionary of all diffusion hyperparameters

global model_config
if train_config['use_model'] == 0:
    model_config = config['wavenet_config']
elif train_config['use_model'] == 1:
    model_config = config['sashimi_config']
elif train_config['use_model'] == 2:
    model_config = config['wavenet_config']

#train(**train_config)
train(output_directory = train_config['output_directory'],
          ckpt_iter = train_config['ckpt_iter'],
          n_iters = train_config['n_iters'],
          iters_per_ckpt = train_config['iters_per_ckpt'],
          iters_per_logging = train_config['iters_per_logging'],
          learning_rate = train_config['learning_rate'],
          use_model = train_config['use_model'],
          only_generate_missing = train_config['only_generate_missing'],
          masking = train_config['masking'],
          missing_k = train_config['missing_k'])
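Two details of the training loop above are worth noting before looking at the output. First, np.split(training_data, 160, 0) chops the training array into 160 equal mini-batches along the first axis, so the number of training samples must be divisible by 160. Second, the masking: mask marks the entries the model conditions on and loss_mask = ~mask.bool() is its complement, so with only_generate_missing=1 the loss is evaluated only on the artificially removed points. The sketch below imitates that call pattern with a hypothetical random_missing_mask helper (a stand-in for get_mask_rm in utils.util, which may differ in detail):

import torch

def random_missing_mask(sample, k):
    # Hypothetical stand-in for utils.util.get_mask_rm:
    # per channel, mark k randomly chosen time steps as missing (0) and keep the rest (1).
    mask = torch.ones_like(sample)                          # sample: (length, channels)
    for channel in range(sample.shape[1]):
        idx = torch.randperm(sample.shape[0])[:k]
        mask[idx, channel] = 0
    return mask

batch = torch.randn(50, 100, 14)                            # toy (batch, length, channels) data
transposed_mask = random_missing_mask(batch[0], k=90)       # (length, channels)
mask = transposed_mask.permute(1, 0)                        # (channels, length)
mask = mask.repeat(batch.size(0), 1, 1).float()             # (batch, channels, length)
loss_mask = ~mask.bool()                                    # complement: the imputed positions
batch = batch.permute(0, 2, 1)                              # (batch, channels, length)
assert batch.size() == mask.size() == loss_mask.size()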
Sample output for reference:
output directory /mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/T200_beta00.0001_betaT0.02
/home/user/venvs/myenv/lib/python3.9/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
  WeightNorm.apply(module, name, dim)
SSSDS4Imputer Parameters: 48.371726M
No valid checkpoint model found, start training from initialization.
Data loaded
iteration: 0    loss: 1.0014768838882446

iteration: 100  loss: 0.519781768321991
model at iteration 100 is saved
iteration: 200  loss: 0.24966183304786682
model at iteration 200 is saved
iteration: 300  loss: 0.07377853989601135
model at iteration 300 is saved
iteration: 400  loss: 0.042258016765117645
model at iteration 400 is saved
iteration: 500  loss: 0.03209486976265907
model at iteration 500 is saved
iteration: 600  loss: 0.03080647811293602
model at iteration 600 is saved
iteration: 700  loss: 0.028232136741280556
model at iteration 700 is saved
iteration: 800  loss: 0.03111139126121998
model at iteration 800 is saved
iteration: 900  loss: 0.027907274663448334
model at iteration 900 is saved
iteration: 1000         loss: 0.019923614338040352
model at iteration 1000 is saved
iteration: 1100         loss: 0.027523808181285858
model at iteration 1100 is saved
iteration: 1200         loss: 0.021600928157567978
model at iteration 1200 is saved
iteration: 1300         loss: 0.025264695286750793
model at iteration 1300 is saved
iteration: 1400         loss: 0.02551073208451271
model at iteration 1400 is saved
iteration: 1500         loss: 0.01813984289765358
model at iteration 1500 is saved
iteration: 1600         loss: 0.016869256272912025
model at iteration 1600 is saved
iteration: 1700         loss: 0.02401333674788475
model at iteration 1700 is saved
iteration: 1800         loss: 0.019163137301802635
model at iteration 1800 is saved
iteration: 1900         loss: 0.020734122022986412
model at iteration 1900 is saved
iteration: 2000         loss: 0.021640272811055183
model at iteration 2000 is saved
iteration: 2100         loss: 0.01770048215985298
model at iteration 2100 is saved
iteration: 2200         loss: 0.018018294125795364
model at iteration 2200 is saved
iteration: 2300         loss: 0.024120550602674484
model at iteration 2300 is saved
iteration: 2400         loss: 0.020138196647167206
model at iteration 2400 is saved
iteration: 2500         loss: 0.017838910222053528
model at iteration 2500 is saved
iteration: 2600         loss: 0.01506974920630455
model at iteration 2600 is saved
iteration: 2700         loss: 0.014618837274610996
model at iteration 2700 is saved
iteration: 2800         loss: 0.012194080278277397
model at iteration 2800 is saved
iteration: 2900         loss: 0.017771299928426743
model at iteration 2900 is saved
iteration: 3000         loss: 0.016776636242866516
model at iteration 3000 is saved
iteration: 3100         loss: 0.016135280951857567
model at iteration 3100 is saved
iteration: 3200         loss: 0.02073829062283039
model at iteration 3200 is saved
iteration: 3300         loss: 0.01661628857254982
model at iteration 3300 is saved
iteration: 3400         loss: 0.014009429141879082
model at iteration 3400 is saved
iteration: 3500         loss: 0.017550712451338768
model at iteration 3500 is saved
iteration: 3600         loss: 0.01515461690723896
model at iteration 3600 is saved
iteration: 3700         loss: 0.011458991095423698
model at iteration 3700 is saved
iteration: 3800         loss: 0.019400596618652344
model at iteration 3800 is saved
iteration: 3900         loss: 0.017088860273361206
model at iteration 3900 is saved
iteration: 4000         loss: 0.017457306385040283
model at iteration 4000 is saved
iteration: 4100         loss: 0.016636019572615623
model at iteration 4100 is saved
iteration: 4200         loss: 0.013571725226938725
model at iteration 4200 is saved
iteration: 4300         loss: 0.011567792855203152
model at iteration 4300 is saved
iteration: 4400         loss: 0.01045661885291338
model at iteration 4400 is saved
iteration: 4500         loss: 0.010916751809418201
model at iteration 4500 is saved
iteration: 4600         loss: 0.009613706730306149
model at iteration 4600 is saved
iteration: 4700         loss: 0.019459472969174385
model at iteration 4700 is saved
iteration: 4800         loss: 0.01087689958512783
model at iteration 4800 is saved
iteration: 4900         loss: 0.013567390851676464
model at iteration 4900 is saved
iteration: 5000         loss: 0.013650226406753063
model at iteration 5000 is saved
iteration: 5100         loss: 0.01359144039452076
model at iteration 5100 is saved
iteration: 5200         loss: 0.00924278236925602
model at iteration 5200 is saved
iteration: 5300         loss: 0.012320103123784065
model at iteration 5300 is saved
iteration: 5400         loss: 0.010287933051586151
model at iteration 5400 is saved
iteration: 5500         loss: 0.01408147718757391
model at iteration 5500 is saved
iteration: 5600         loss: 0.008416304364800453
model at iteration 5600 is saved
iteration: 5700         loss: 0.00916117150336504
model at iteration 5700 is saved
iteration: 5800         loss: 0.010804965160787106
model at iteration 5800 is saved
iteration: 5900         loss: 0.01115892268717289
model at iteration 5900 is saved
iteration: 6000         loss: 0.01513115968555212
model at iteration 6000 is saved
iteration: 6100         loss: 0.009549148380756378
model at iteration 6100 is saved
iteration: 6200         loss: 0.00874475110322237
model at iteration 6200 is saved
iteration: 6300         loss: 0.011548626236617565
model at iteration 6300 is saved
iteration: 6400         loss: 0.009166202507913113
model at iteration 6400 is saved
iteration: 6500         loss: 0.00808231346309185
model at iteration 6500 is saved
iteration: 6600         loss: 0.010756433941423893
model at iteration 6600 is saved
iteration: 6700         loss: 0.010254732333123684
model at iteration 6700 is saved
iteration: 6800         loss: 0.00756409578025341
model at iteration 6800 is saved
iteration: 6900         loss: 0.009908380918204784
model at iteration 6900 is saved
iteration: 7000         loss: 0.0076458347029984
model at iteration 7000 is saved
iteration: 7100         loss: 0.010375996120274067
model at iteration 7100 is saved
iteration: 7200         loss: 0.006446328014135361
model at iteration 7200 is saved
iteration: 7300         loss: 0.009513751603662968
model at iteration 7300 is saved
iteration: 7400         loss: 0.007698349189013243
model at iteration 7400 is saved
iteration: 7500         loss: 0.011298432014882565
model at iteration 7500 is saved
iteration: 7600         loss: 0.009159176610410213
model at iteration 7600 is saved
iteration: 7700         loss: 0.009985801763832569
model at iteration 7700 is saved
iteration: 7800         loss: 0.008555536158382893
model at iteration 7800 is saved
iteration: 7900         loss: 0.008304151706397533
model at iteration 7900 is saved
iteration: 8000         loss: 0.009995604865252972
model at iteration 8000 is saved
iteration: 8100         loss: 0.010978537611663342
model at iteration 8100 is saved
iteration: 8200         loss: 0.009314262308180332
model at iteration 8200 is saved
iteration: 8300         loss: 0.012210099026560783
model at iteration 8300 is saved
iteration: 8400         loss: 0.014000611379742622
model at iteration 8400 is saved
iteration: 8500         loss: 0.008319240994751453
model at iteration 8500 is saved
iteration: 8600         loss: 0.009704609401524067
model at iteration 8600 is saved
iteration: 8700         loss: 0.01124562043696642
model at iteration 8700 is saved
iteration: 8800         loss: 0.00955492164939642
model at iteration 8800 is saved
iteration: 8900         loss: 0.008445963263511658
model at iteration 8900 is saved
iteration: 9000         loss: 0.007299567572772503
model at iteration 9000 is saved

Run \SSSD-main\src\inference.py to load the trained model checkpoint and generate predictions.

Run the following in the terminal:

python inference.py --config /mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/config/config_SSSDS4.json --ckpt_iter max --num_samples 500
import os
import argparse
import json
import numpy as np
import torch

from utils.util import get_mask_mnr, get_mask_bm, get_mask_rm
from utils.util import find_max_epoch, print_size, sampling, calc_diffusion_hyperparams

from imputers.DiffWaveImputer import DiffWaveImputer
from imputers.SSSDSAImputer import SSSDSAImputer
from imputers.SSSDS4Imputer import SSSDS4Imputer

from sklearn.metrics import mean_squared_error
from statistics import mean


def generate(output_directory,
             num_samples,
             ckpt_path,
             data_path,
             ckpt_iter,
             use_model,
             masking,
             missing_k,
             only_generate_missing):
  
    """
    Generate data based on ground truth 

    Parameters:
    output_directory (str):           save generated speeches to this path
    num_samples (int):                number of samples to generate, default is 4
    ckpt_path (str):                  checkpoint path
    ckpt_iter (int or 'max'):         the pretrained checkpoint to be loaded; 
                                      automatically selects the maximum iteration if 'max' is selected
    data_path (str):                  path to dataset, numpy array.
    use_model (int):                  0:DiffWave. 1:SSSDSA. 2:SSSDS4.
    masking (str):                    'mnr': missing not at random, 'bm': black-out, 'rm': random missing
    only_generate_missing (int):      0:all sample diffusion.  1:only apply diffusion to missing portions of the signal
    missing_k (int)                   k missing time points for each channel across the length.
    """

    # generate experiment (local) path
    local_path = "T{}_beta0{}_betaT{}".format(diffusion_config["T"],
                                              diffusion_config["beta_0"],
                                              diffusion_config["beta_T"])

    # Get shared output_directory ready
    output_directory = os.path.join(output_directory, local_path)
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory, flush=True)

    # map diffusion hyperparameters to gpu
    for key in diffusion_hyperparams:
        if key != "T":
            diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()

    
    # predefine model
    if use_model == 0:
        net = DiffWaveImputer(**model_config).cuda()
    elif use_model == 1:
        net = SSSDSAImputer(**model_config).cuda()
    elif use_model == 2:
        net = SSSDS4Imputer(**model_config).cuda()
    else:
        print('Model chosen not available.')
    print_size(net)

  
    # load checkpoint
    ckpt_path = os.path.join(ckpt_path, local_path)
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_path)
    model_path = os.path.join(ckpt_path, '{}.pkl'.format(ckpt_iter))
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        print('Successfully loaded model at iteration {}'.format(ckpt_iter))
    except:
        raise Exception('No valid model found')

  
  
    ### Custom data loading and reshaping ###
  
    testing_data = np.load(trainset_config['test_data_path'])
    testing_data = np.split(testing_data, 4, 0)
    testing_data = np.array(testing_data)
    testing_data = torch.from_numpy(testing_data).float().cuda()
    print('Data loaded')

    all_mse = []

  
    for i, batch in enumerate(testing_data):

        if masking == 'mnr':
            mask_T = get_mask_mnr(batch[0], missing_k)
            mask = mask_T.permute(1, 0)
            mask = mask.repeat(batch.size()[0], 1, 1)
            mask = mask.type(torch.float).cuda()

        elif masking == 'bm':
            mask_T = get_mask_bm(batch[0], missing_k)
            mask = mask_T.permute(1, 0)
            mask = mask.repeat(batch.size()[0], 1, 1)
            mask = mask.type(torch.float).cuda()

        elif masking == 'rm':
            mask_T = get_mask_rm(batch[0], missing_k)
            mask = mask_T.permute(1, 0)
            mask = mask.repeat(batch.size()[0], 1, 1).float().cuda()

    
    
        batch = batch.permute(0,2,1)
  
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

        sample_length = batch.size(2)
        sample_channels = batch.size(1)
        generated_audio = sampling(net, (num_samples, sample_channels, sample_length),
                                   diffusion_hyperparams,
                                   cond=batch,
                                   mask=mask,
                                   only_generate_missing=only_generate_missing)

        end.record()
        torch.cuda.synchronize()

        print('generated {} utterances of random_digit at iteration {} in {} seconds'.format(num_samples,
                                                                                             ckpt_iter,
                                                                                             int(start.elapsed_time(
                                                                                                 end) / 1000)))

  
        generated_audio = generated_audio.detach().cpu().numpy()
        batch = batch.detach().cpu().numpy()
        mask = mask.detach().cpu().numpy() 
  
  
        outfile = f'imputation{i}.npy'
        new_out = os.path.join(ckpt_path, outfile)
        np.save(new_out, generated_audio)

        outfile = f'original{i}.npy'
        new_out = os.path.join(ckpt_path, outfile)
        np.save(new_out, batch)

        outfile = f'mask{i}.npy'
        new_out = os.path.join(ckpt_path, outfile)
        np.save(new_out, mask)

        print('saved generated samples at iteration %s' % ckpt_iter)
  
        mse = mean_squared_error(generated_audio[~mask.astype(bool)], batch[~mask.astype(bool)])
        all_mse.append(mse)
  
    print('Total MSE:', mean(all_mse))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config.json',
                        help='JSON file for configuration')
    parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
                        help='Which checkpoint to use; assign a number or "max"')
    parser.add_argument('-n', '--num_samples', type=int, default=500,
                        help='Number of utterances to be generated')
    args = parser.parse_args()

    # Parse configs. Globals nicer in this case
    with open(args.config) as f:
        data = f.read()
    config = json.loads(data)
    print(config)

    gen_config = config['gen_config']

    train_config = config["train_config"]  # training parameters

    global trainset_config
    trainset_config = config["trainset_config"]  # to load trainset

    global diffusion_config
    diffusion_config = config["diffusion_config"]  # basic hyperparameters

    global diffusion_hyperparams
    diffusion_hyperparams = calc_diffusion_hyperparams(
        **diffusion_config)  # dictionary of all diffusion hyperparameters

    global model_config
    if train_config['use_model'] == 0:
        model_config = config['wavenet_config']
    elif train_config['use_model'] == 1:
        model_config = config['sashimi_config']
    elif train_config['use_model'] == 2:
        model_config = config['wavenet_config']

    generate(**gen_config,
             ckpt_iter=args.ckpt_iter,
             num_samples=args.num_samples,
             use_model=train_config["use_model"],
             data_path=trainset_config["test_data_path"],
             masking=train_config["masking"],
             missing_k=train_config["missing_k"],
             only_generate_missing=train_config["only_generate_missing"])
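sampling() in utils.util runs the full reverse diffusion: starting from Gaussian noise it iterates the T=200 reverse steps, and with only_generate_missing=1 the observed entries are kept clamped to the conditioning signal so that only the masked positions are actually generated. The snippet below is a minimal sketch of one such reverse step under the standard DDPM formulation, reusing Alpha, Alpha_bar, and Sigma from the schedule sketch earlier; eps_model is a hypothetical stand-in for the trained imputer, and this should not be read as the repository's exact implementation.

import torch

def reverse_step(x_t, t, eps_model, Alpha, Alpha_bar, Sigma, cond, mask):
    # One standard DDPM reverse step from x_t to x_{t-1}.
    eps = eps_model(x_t, t)                                          # predicted noise, same shape as x_t
    mean = (x_t - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * eps) / torch.sqrt(Alpha[t])
    z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)    # no noise at the final step
    x_prev = mean + Sigma[t] * z
    # Clamp observed entries to the conditioning signal; generate only the missing ones.
    return x_prev * (1 - mask) + cond * mask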
Sample output for reference:
(myenv) user@LAPTOP-KOPTLCHM:/mnt/d/Code/sssd_cp_learning_and_testing/SSSD-main/src$ python inference.py --config /mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/config/config_SSSDS4.json --ckpt_iter max --num_samples 500
CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%
Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.
{'diffusion_config': {'T': 200, 'beta_0': 0.0001, 'beta_T': 0.02}, 'wavenet_config': {'in_channels': 14, 'out_channels': 14, 'num_res_layers': 36, 'res_channels': 256, 'skip_channels': 256, 'diffusion_step_embed_dim_in': 128, 'diffusion_step_embed_dim_mid': 512, 'diffusion_step_embed_dim_out': 512, 's4_lmax': 100, 's4_d_state': 64, 's4_dropout': 0.0, 's4_bidirectional': 1, 's4_layernorm': 1}, 'train_config': {'output_directory': '/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/', 'ckpt_iter': 'max', 'iters_per_ckpt': 100, 'iters_per_logging': 100, 'n_iters': 1000, 'learning_rate': 0.0002, 'only_generate_missing': 1, 'use_model': 2, 'masking': 'rm', 'missing_k': 90}, 'trainset_config': {'train_data_path': '/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/datasets/Mujoco/train_mujoco.npy', 'test_data_path': '/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/datasets/Mujoco/test_mujoco.npy', 'segment_length': 100, 'sampling_rate': 100}, 'gen_config': {'output_directory': '/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/', 'ckpt_path': '/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/'}}
output directory /mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/T200_beta00.0001_betaT0.02
/home/user/venvs/myenv/lib/python3.9/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
  WeightNorm.apply(module, name, dim)
SSSDS4Imputer Parameters: 48.371726M
Successfully loaded model at iteration 13100
Data loaded
begin sampling, total number of reverse steps = 200
generated 500 utterances of random_digit at iteration 13100 in 522 seconds
saved generated samples at iteration 13100
begin sampling, total number of reverse steps = 200
generated 500 utterances of random_digit at iteration 13100 in 522 seconds
saved generated samples at iteration 13100
begin sampling, total number of reverse steps = 200
generated 500 utterances of random_digit at iteration 13100 in 522 seconds
saved generated samples at iteration 13100
begin sampling, total number of reverse steps = 200
generated 500 utterances of random_digit at iteration 13100 in 522 seconds
saved generated samples at iteration 13100
Total MSE: 0.009962185751646757
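The per-batch files saved above (imputation{i}.npy, original{i}.npy, mask{i}.npy) make it easy to inspect the imputations offline. The sketch below reloads one batch and recomputes the MSE over the imputed positions only, which is how the "Total MSE" line above is obtained in inference.py; the directory is the one printed in the log, so adjust it to your own run.

import numpy as np

ckpt_dir = "/mnt/d/Code/sssd_cp_learning_and_testing/learning_and_testing/SSSD/results/mujoco/90/T200_beta00.0001_betaT0.02"

i = 0
generated = np.load(f"{ckpt_dir}/imputation{i}.npy")   # (num_samples, channels, length)
original = np.load(f"{ckpt_dir}/original{i}.npy")      # (batch, channels, length); here batch == num_samples == 500
mask = np.load(f"{ckpt_dir}/mask{i}.npy")              # 1 = observed, 0 = imputed

missing = ~mask.astype(bool)                           # positions that were imputed
mse = np.mean((generated[missing] - original[missing]) ** 2)
print(f"batch {i} MSE on imputed points: {mse:.6f}")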

A follow-up learning topic from this study: an introduction to the PyTorch Python module and its basic syntax.

References