1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
| import os
import argparse
import json
import numpy as np
import torch
from utils.util import get_mask_mnr, get_mask_bm, get_mask_rm
from utils.util import find_max_epoch, print_size, sampling, calc_diffusion_hyperparams
from imputers.DiffWaveImputer import DiffWaveImputer
from imputers.SSSDSAImputer import SSSDSAImputer
from imputers.SSSDS4Imputer import SSSDS4Imputer
from sklearn.metrics import mean_squared_error
from statistics import mean
def generate(output_directory,
             num_samples,
             ckpt_path,
             data_path,
             ckpt_iter,
             use_model,
             masking,
             missing_k,
             only_generate_missing,
             num_batches=4):
    """
    Impute masked test data with a trained diffusion model and report the MSE.

    Relies on the module-level globals ``diffusion_config``,
    ``diffusion_hyperparams`` and ``model_config`` being set before the call.
    ``diffusion_hyperparams`` is modified in place (its tensors are moved to
    the GPU). A CUDA device is required.

    Parameters:
    output_directory (str):       base directory; the experiment sub-directory is
                                  created beneath it if missing
    num_samples (int):            first dimension of the size passed to `sampling`;
                                  presumably has to match the per-batch sample
                                  count since the batch is used as conditioning
                                  -- TODO confirm
    ckpt_path (str):              base checkpoint directory; the experiment
                                  sub-directory is appended to it
    data_path (str):              path to the test dataset, a numpy array file
    ckpt_iter (int or 'max'):     the pretrained checkpoint to be loaded;
                                  automatically selects the maximum iteration
                                  if 'max' is given
    use_model (int):              0:DiffWave. 1:SSSDSA. 2:SSSDS4.
    masking (str):                'mnr': missing not at random, 'bm': black-out,
                                  'rm': random missing
    missing_k (int):              k missing time points for each channel across
                                  the length
    only_generate_missing (int):  0:all sample diffusion. 1:only apply diffusion
                                  to missing portions of the signal
    num_batches (int):            number of equal parts the test array is split
                                  into along axis 0 (default 4, the previously
                                  hard-coded value)

    Raises:
    ValueError: if `use_model` or `masking` is not one of the supported values,
                or (from numpy) if the data cannot be split evenly.
    Exception:  'No valid model found' if the checkpoint cannot be loaded; the
                underlying error is chained as the cause.
    """
    # Resolve the configurable choices first so a bad config fails immediately
    # instead of after directories were created and tensors moved to the GPU.
    mask_fn = _resolve_mask_fn(masking)
    imputer_cls = _resolve_imputer_cls(use_model)

    # generate experiment (local) path
    local_path = "T{}_beta0{}_betaT{}".format(diffusion_config["T"],
                                              diffusion_config["beta_0"],
                                              diffusion_config["beta_T"])

    # Get shared output_directory ready
    output_directory = os.path.join(output_directory, local_path)
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory, flush=True)

    # map diffusion hyperparameters to gpu ("T" is skipped: it is not a tensor)
    for key in diffusion_hyperparams:
        if key != "T":
            diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()

    # predefine model
    net = imputer_cls(**model_config).cuda()
    print_size(net)

    # load checkpoint
    ckpt_path = os.path.join(ckpt_path, local_path)
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_path)
    model_path = os.path.join(ckpt_path, '{}.pkl'.format(ckpt_iter))
    try:
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
    except Exception as err:
        # Same exception type and message as before, but keep the root cause.
        raise Exception('No valid model found') from err
    print('Successfully loaded model at iteration {}'.format(ckpt_iter))

    ### Custom data loading and reshaping ###
    testing_data = np.load(data_path)
    # Split along axis 0 into equally sized batches and stack them into one
    # array with a leading batch-index dimension.
    testing_data = np.array(np.split(testing_data, num_batches, 0))
    testing_data = torch.from_numpy(testing_data).float().cuda()
    print('Data loaded')

    all_mse = []
    for i, batch in enumerate(testing_data):
        # The mask is built from the first sample (before the permute below)
        # and shared by every sample in the batch.
        mask = _build_mask(mask_fn, batch, missing_k)

        # Swap the last two axes; dim 1 is then treated as channels and
        # dim 2 as length (see sample_channels / sample_length below).
        batch = batch.permute(0, 2, 1)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

        sample_length = batch.size(2)
        sample_channels = batch.size(1)
        generated_audio = sampling(net, (num_samples, sample_channels, sample_length),
                                   diffusion_hyperparams,
                                   cond=batch,
                                   mask=mask,
                                   only_generate_missing=only_generate_missing)

        end.record()
        # elapsed_time is only valid once both events have completed.
        torch.cuda.synchronize()
        print('generated {} utterances of random_digit at iteration {} in {} seconds'.format(
            num_samples,
            ckpt_iter,
            int(start.elapsed_time(end) / 1000)))

        generated_audio = generated_audio.detach().cpu().numpy()
        batch = batch.detach().cpu().numpy()
        mask = mask.detach().cpu().numpy()

        # NOTE(review): results are written next to the checkpoint (ckpt_path),
        # not into output_directory, although the latter is created above.
        # Kept as-is because downstream tooling may rely on this location.
        for prefix, array in (('imputation', generated_audio),
                              ('original', batch),
                              ('mask', mask)):
            np.save(os.path.join(ckpt_path, f'{prefix}{i}.npy'), array)
        print('saved generated samples at iteration %s' % ckpt_iter)

        # Score only the positions where the mask is zero (presumably the
        # entries hidden from the model -- confirm against get_mask_*).
        missing = ~mask.astype(bool)
        mse = mean_squared_error(generated_audio[missing], batch[missing])
        all_mse.append(mse)

    print('Total MSE:', mean(all_mse))


def _resolve_mask_fn(masking):
    """Return the mask generator for `masking`; raise ValueError if unknown."""
    mask_fns = {'mnr': get_mask_mnr, 'bm': get_mask_bm, 'rm': get_mask_rm}
    if masking not in mask_fns:
        raise ValueError(
            "Unknown masking {!r}; expected one of {}".format(masking, sorted(mask_fns)))
    return mask_fns[masking]


def _resolve_imputer_cls(use_model):
    """Return the imputer class for `use_model`; raise ValueError if unknown."""
    imputers = {0: DiffWaveImputer, 1: SSSDSAImputer, 2: SSSDS4Imputer}
    if use_model not in imputers:
        raise ValueError(
            "Model chosen not available: use_model={!r}; expected 0, 1 or 2".format(use_model))
    return imputers[use_model]


def _build_mask(mask_fn, batch, missing_k):
    """Build a float CUDA mask from the first sample, repeated for the whole batch."""
    mask_2d = mask_fn(batch[0], missing_k)
    # Transpose the 2-D mask, then tile it along a new leading batch dimension.
    mask = mask_2d.permute(1, 0).repeat(batch.size(0), 1, 1)
    return mask.float().cuda()
if __name__ == "__main__":
    # Command-line entry point: read the JSON config, publish the sections
    # generate() reads as module-level globals, then run the imputation.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config.json',
                        help='JSON file for configuration')
    parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
                        help='Which checkpoint to use; assign a number or "max"')
    parser.add_argument('-n', '--num_samples', type=int, default=500,
                        help='Number of utterances to be generated')
    args = parser.parse_args()

    # Parse configs. Globals nicer in this case: names assigned here live at
    # module scope, so generate() can read them without `global` statements.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)
    print(config)

    gen_config = config['gen_config']
    train_config = config["train_config"]  # training parameters
    trainset_config = config["trainset_config"]  # to load trainset
    diffusion_config = config["diffusion_config"]  # basic hyperparameters
    diffusion_hyperparams = calc_diffusion_hyperparams(
        **diffusion_config)  # dictionary of all diffusion hyperparameters

    # Config section holding the network hyperparameters for each model id.
    # Model 2 reads 'wavenet_config', exactly as before.
    model_config_sections = {0: 'wavenet_config',
                             1: 'sashimi_config',
                             2: 'wavenet_config'}
    use_model = train_config['use_model']
    if use_model not in model_config_sections:
        # Previously model_config stayed undefined here and the failure only
        # surfaced later as a NameError inside generate().
        raise ValueError(
            "Unsupported use_model {!r}; expected 0, 1 or 2".format(use_model))
    model_config = config[model_config_sections[use_model]]

    generate(**gen_config,
             ckpt_iter=args.ckpt_iter,
             num_samples=args.num_samples,
             use_model=use_model,
             data_path=trainset_config["test_data_path"],
             masking=train_config["masking"],
             missing_k=train_config["missing_k"],
             only_generate_missing=train_config["only_generate_missing"])
|