$SSSD^{S4}$ Model Modifications

1141224 meeting

This note summarizes all recent meetings and lists every modification made.

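The changes below modify the training iteration of the original SSSD model so that, during iteration, it computes the autoFRK result and uses the resulting error to adjust the model. The workflow is as follows: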
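
As a rough sketch of this workflow (not the trainer itself), each iteration can be reduced to two phases: an ordinary diffusion update, and a periodic spatial update through autoFRK. The function name `run_schedule` and the two callbacks are hypothetical stand-ins; only `n_iters`, `autoFRK_period`, and `enable_spatial_prediction` are names taken from the training configuration in this note.

```python
from typing import Callable, List


def run_schedule(
    n_iters: int,
    autoFRK_period: int,
    enable_spatial_prediction: bool,
    diffusion_update: Callable[[int], None],
    spatial_update: Callable[[int], None],
) -> None:
    """Call the two update phases in the order the modified trainer uses them."""
    for n_iter in range(1, n_iters + 1):
        # Phase 1: ordinary SSSD update on the noise-prediction loss.
        diffusion_update(n_iter)
        # Phase 2: every `autoFRK_period` iterations, sample from the model,
        # pass the samples through autoFRK, and update on the resulting error.
        if enable_spatial_prediction and n_iter % autoFRK_period == 0:
            spatial_update(n_iter)


# Hypothetical callbacks that only record when each phase runs.
events: List[str] = []
run_schedule(
    n_iters=40,
    autoFRK_period=20,
    enable_spatial_prediction=True,
    diffusion_update=lambda i: events.append(f"diffusion@{i}"),
    spatial_update=lambda i: events.append(f"spatial@{i}"),
)
spatial_events = [e for e in events if e.startswith("spatial")]
print(spatial_events)  # ['spatial@20', 'spatial@40']
```

With `enable_spatial_prediction` set to false, the loop degenerates to the original SSSD training schedule.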

Modifications

The modified code is shown below:

config

/configs/model.yaml (unchanged)

wavenet:
  # WaveNet model parameters
  input_channels: 24  # Number of input channels
  output_channels: 24  # Number of output channels
  residual_layers: 32  # Number of residual layers
  residual_channels: 32  # Number of channels in residual blocks
  skip_channels: 32  # Number of channels in skip connections

  # Diffusion step embedding dimensions
  diffusion_step_embed_dim_input: 64  # Input dimension
  diffusion_step_embed_dim_hidden: 64  # Middle dimension
  diffusion_step_embed_dim_output: 64  # Output dimension

  # Structured State Spaces sequence model (S4) configurations
  s4_max_sequence_length: 292  # Maximum sequence length
  s4_state_dim: 64  # State dimension
  s4_dropout: 0.0  # Dropout rate
  s4_bidirectional: true  # Whether to use bidirectional layers
  s4_use_layer_norm: true  # Whether to use layer normalization

diffusion:
  # Diffusion model parameters
  T: 200  # Number of diffusion steps
  beta_0: 0.0001  # Initial beta value
  beta_T: 0.02  # Final beta value
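
The `diffusion` block stores only `T`, `beta_0`, and `beta_T`, while the trainer later reads `Alpha`, `Alpha_bar`, and `Sigma` from `diffusion_hyperparams`. The sketch below shows one way those arrays could be derived, assuming a linear beta schedule and the standard DDPM posterior variance. The helper name `linear_diffusion_schedule` is hypothetical, and the project's own `calc_diffusion_hyperparams` is not shown in this note and may differ.

```python
import math
from typing import Dict, List


def linear_diffusion_schedule(T: int, beta_0: float, beta_T: float) -> Dict[str, List[float]]:
    """Derive Alpha, Alpha_bar and Sigma from a linear beta schedule (assumed)."""
    step = (beta_T - beta_0) / (T - 1)
    beta = [beta_0 + step * t for t in range(T)]
    alpha = [1.0 - b for b in beta]

    # Alpha_bar[t] is the running product of Alpha[0..t].
    alpha_bar: List[float] = []
    running = 1.0
    for a in alpha:
        running *= a
        alpha_bar.append(running)

    # Sigma[t] is the square root of the posterior variance beta_tilde[t].
    sigma = [math.sqrt(beta[0])]
    for t in range(1, T):
        beta_tilde = beta[t] * (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t])
        sigma.append(math.sqrt(beta_tilde))

    return {"Alpha": alpha, "Alpha_bar": alpha_bar, "Sigma": sigma}


# Values taken from the `diffusion` block above.
schedule = linear_diffusion_schedule(T=200, beta_0=0.0001, beta_T=0.02)
print(len(schedule["Alpha_bar"]), schedule["Alpha_bar"][0])
```

Under these assumptions `Alpha_bar` decreases monotonically, so later diffusion steps carry less of the original signal and more noise.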

/configs/training.yaml (added autoFRK-related settings)

# Training configuration
batch_size: 64  # Batch size
output_directory: "/home/u6025091/SSSD_CP/results/NASA-GES-DISC-fast-rectangular"  # Output directory for checkpoints and logs
ckpt_iter: "max"  # Checkpoint iteration to resume from ("max" selects the latest)
iters_per_ckpt: 1000  # Checkpoint frequency (number of iterations)
iters_per_logging: 200  # Log frequency (number of iterations)
n_iters: 38000  # Maximum number of iterations
learning_rate: 0.0005  # Learning rate

# Additional training settings
only_generate_missing: true  # Generate missing values only
use_model: 2  # Model to use for training
masking: "forecast"  # Masking strategy for missing values
missing_k: 22  # Number of missing time steps per feature

# Data paths
data:
  train_path: "/home/u6025091/SSSD_CP/datasets/NASA-GES-DISC/data_train_known_real.npy"  # Path to training data

# autoFRK config
enable_spatial_prediction: false  # Enable spatial prediction step
autoFRK_period: 20  # Run the autoFRK step every this many iterations
location_path: "/home/u6025091/SSSD_CP/datasets/NASA-GES-DISC/stations_known_locations.npy"  # Path to known locations
AFRK_method: "fast"
AFRK_tps_method: "rectangular"

/configs/inference.yaml (added autoFRK-related settings)

# Inference configuration
batch_size: 64  # Batch size for inference
output_directory: "/home/u6025091/SSSD_CP/results/NASA-GES-DISC-fast-rectangular/inference"  # Output directory for inference results
ckpt_path: "/home/u6025091/SSSD_CP/results/NASA-GES-DISC-fast-rectangular"  # Path to checkpoint for inference
trials: 1 # Replications

# Additional inference settings
only_generate_missing: true  # Generate missing values only
use_model: 2  # Model to use for inference
masking: "forecast"  # Masking strategy for missing values
missing_k: 22  # Number of missing time steps per feature

# Data paths
data:
  test_path: "/home/u6025091/SSSD_CP/datasets/NASA-GES-DISC/data_test_missing.npy"  # Path to test data

# autoFRK config
enable_spatial_inference: true  # Enable spatial prediction step
known_location_path: "/home/u6025091/SSSD_CP/datasets/NASA-GES-DISC/stations_known_locations.npy"  # Path to known locations
unknown_location_path: "/home/u6025091/SSSD_CP/datasets/NASA-GES-DISC/stations_unknown_locations.npy"  # Path to unknown locations
AFRK_method: "fast"
AFRK_tps_method: "rectangular"

training

All files below are located in the sssd/training/ directory.

trainer.py

import logging
import os
from typing import Any, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.checkpoint import checkpoint  # New
from tqdm import tqdm
import numpy as np  # New
import csv  # New

from sssd.core.model_specs import MASK_FN
from sssd.training.utils import training_loss
from sssd.data.utils import get_dataloader
from sssd.utils.logger import setup_logger
from sssd.utils.utils import find_max_epoch, std_normal  # New
from autoFRK import AutoFRK, to_tensor, garbage_cleaner  # New
from autoFRK.utils.helper import cbrt  # New

LOGGER = setup_logger()


class DiffusionTrainer:
    """
    Train Diffusion Models

    Args:
        data_path (str): Path to the training data; the dataloader is built from it.
        diffusion_hyperparams (Dict[str, Any]): Hyperparameters for the diffusion process.
        net (nn.Module): The neural network model to be trained.
        device (torch.device): The device to be used for training.
        output_directory (str): Directory to save model checkpoints.
        ckpt_iter (Union[str, int]): The checkpoint iteration to be loaded; 'max' selects the maximum iteration.
        n_iters (int): Number of iterations to train.
        iters_per_ckpt (int): Number of iterations between checkpoints.
        iters_per_logging (int): Number of iterations between training-log messages.
        learning_rate (float): Learning rate for training.
        only_generate_missing (int): Option to generate missing portions of the signal only.
        masking (str): Type of masking strategy; must be a key of MASK_FN, e.g. 'mnr' for Missing Not at Random, 'bm' for Blackout Missing, 'rm' for Random Missing.
        missing_k (int): K missing time steps for each feature across the sample length.
        batch_size (int): Size of each training batch.
        enable_spatial_prediction (bool): Whether to run the autoFRK spatial prediction step during training.
        autoFRK_period (int): Run the spatial prediction step every `autoFRK_period` iterations.
        location_path (str): Path to the .npy file holding the known station locations.
        AFRK_method (str): Value passed to AutoFRK.forward as `method`.
        AFRK_tps_method (str): Value passed to AutoFRK.forward as `tps_method`.
        logger (Optional[logging.Logger]): Logger object for logging, defaults to None.
    """

    def __init__(
        self,
        data_path: str,
        diffusion_hyperparams: Dict[str, Any],
        net: nn.Module,
        device: Optional[Union[torch.device, str]],
        output_directory: str,
        ckpt_iter: Union[str, int],
        n_iters: int,
        iters_per_ckpt: int,
        iters_per_logging: int,
        learning_rate: float,
        only_generate_missing: int,
        masking: str,
        missing_k: int,
        batch_size: int,
        enable_spatial_prediction: bool,  # New
        autoFRK_period: int,  # New
        location_path: str,  # New
        AFRK_method: str,  # New
        AFRK_tps_method: str,  # New
        logger: Optional[logging.Logger] = None,
    ) -> None:
        loader, ts_mean, ts_std = get_dataloader(
            path=data_path,
            batch_size=batch_size,
            device=device,
        )
        self.dataloader = loader
        self.ts_mean = ts_mean
        self.ts_std = ts_std
        self.diffusion_hyperparams = diffusion_hyperparams
        self.net = nn.DataParallel(net).to(device)
        self.device = device
        self.output_directory = output_directory
        self.ckpt_iter = ckpt_iter
        self.n_iters = n_iters
        self.iters_per_ckpt = iters_per_ckpt
        self.iters_per_logging = iters_per_logging
        self.learning_rate = learning_rate
        self.only_generate_missing = only_generate_missing
        self.masking = masking
        self.missing_k = missing_k
        self.writer = SummaryWriter(f"{output_directory}/log")
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        self.enable_spatial_prediction = enable_spatial_prediction  # New
        self.autoFRK_period = autoFRK_period  # New
        self.location_path = location_path  # New
        self.real_data = self.dataloader.dataset.tensors[0].to(self.device) # * self.ts_std + self.ts_mean # New
        self.real_data_shape = self.real_data.shape  # New
        self.logger = logger or LOGGER
        self.AFRK_mrts = None  # New
        self.AFRK_method = AFRK_method  # New
        self.AFRK_tps_method = AFRK_tps_method  # New
        self.loc = to_tensor(np.load(self.location_path), dtype=self.real_data.dtype, device=device)

        frk = AutoFRK(
            logger_level=30,
            dtype=self.real_data.dtype,
            device=self.real_data.device,
        )
        
        # autoFRK
        def frk_step(data_slice, mrts):
            _ = frk.forward(
                data=data_slice,
                loc=self.loc,
                G=mrts,
                method=self.AFRK_method,
                tps_method=self.AFRK_tps_method,
                requires_grad=True
            )
            pred = frk.predict()['pred.value']
            return pred.T, frk.obj['G']
        #self.frk_step = torch.compile(frk_step)
        self.frk_step = frk_step

        if self.masking not in MASK_FN:
            raise KeyError(f"Please enter a correct masking, but got {self.masking}")

    def _load_checkpoint(self) -> None:
        if self.ckpt_iter == "max":
            self.ckpt_iter = find_max_epoch(self.output_directory)
        if self.ckpt_iter >= 0:
            try:
                model_path = os.path.join(
                    self.output_directory, f"{self.ckpt_iter}.pkl"
                )
                checkpoint = torch.load(model_path, map_location="cpu")

                self.net.load_state_dict(checkpoint["model_state_dict"])
                if "optimizer_state_dict" in checkpoint:
                    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

                self.logger.info(
                    f"Successfully loaded model at iteration {self.ckpt_iter}"
                )
            except Exception as e:
                self.ckpt_iter = -1
                self.logger.error(f"No valid checkpoint model found. Error: {e}")
        else:
            self.ckpt_iter = -1
            self.logger.info(
                "No valid checkpoint model found, start training from initialization."
            )

    def _save_model(self, n_iter: int) -> None:
        if n_iter > 0 and n_iter % self.iters_per_ckpt == 0:
            torch.save(
                {
                    "model_state_dict": self.net.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                os.path.join(self.output_directory, f"{n_iter}.pkl"),
            )

    def _update_mask(self, batch: torch.Tensor) -> torch.Tensor:
        transposed_mask = MASK_FN[self.masking](batch[0], self.missing_k)
        return (
            transposed_mask.permute(1, 0)
            .repeat(batch.size()[0], 1, 1)
            .to(self.device, dtype=torch.float32)
        )
    
    def _sampling_differentiable(
        self,
        net: torch.nn.Module,
        size: Tuple[int, int, int],
        diffusion_hyperparams: Dict[str, torch.Tensor],
        cond: torch.Tensor,
        mask: torch.Tensor,
        only_generate_missing: int = 0,
        device: Union[torch.device, str] = "cpu",
        sampling_mode: str = "ddpm",  # "ddpm" or "ddim"
        step_skip: Optional[int] = None  # auto-detected from T when None
    ) -> torch.Tensor:
        _dh = diffusion_hyperparams
        T, Alpha, Alpha_bar, Sigma = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"]
        x = torch.randn(size, device=device)
        mask = mask.to(device)
        cond = cond.to(device)

        # Auto-detect a suitable step_skip
        if step_skip is None:
            if T <= 200:
                step_skip = 1       # small model: full-step, fine-grained inference
            elif T <= 1000:
                step_skip = max(1, T // 200)  # typical model: finish in about 200 steps
            else:
                step_skip = max(1, T // 400)  # large model: keep it to about 400 steps
            #LOGGER.info(f"T={T}, using step_skip={step_skip}")
        step_skip = max(1, step_skip)

        # Build the strided timestep schedule
        timesteps = list(range(T - 1, -1, -step_skip))
        if timesteps[-1] != 0:
            timesteps.append(0)

        noise_list = []
        if sampling_mode == "ddpm":
            for _ in timesteps:
                # pre-generate noise with same shape as x; deterministic for forward+recompute
                noise_list.append(torch.randn(size, device=device))
        else:
            # keep placeholder list (not used for ddim)
            noise_list = [None] * len(timesteps)

        for idx, t in enumerate(timesteps):
            noise_t = noise_list[idx]
            def net_step(x_inner, noise=noise_t, tt=t):
                if only_generate_missing == 1:
                    x_inner = x_inner * (1 - mask).float() + cond * mask.float()
                diffusion_steps = (tt * torch.ones((size[0], 1), device=device))
                epsilon_theta = net((x_inner, cond, mask, diffusion_steps))

                # Choose the sampling formula according to sampling_mode
                if sampling_mode == "ddpm":
                    # DDPM: stochastic version
                    x_next = (x_inner - (1 - Alpha[tt]) / torch.sqrt(1 - Alpha_bar[tt]) * epsilon_theta) / torch.sqrt(Alpha[tt])
                    if tt > 0:
                        x_next = x_next + Sigma[tt] * noise
                elif sampling_mode == "ddim":
                    # DDIM: deterministic version, supports step skipping
                    if tt == 0:
                        x_next = torch.sqrt(Alpha_bar[tt]) * (
                            (x_inner - torch.sqrt(1 - Alpha_bar[tt]) * epsilon_theta)
                            / torch.sqrt(Alpha_bar[tt])
                        )
                    else:
                        Alpha_bar_prev = Alpha_bar[max(tt - step_skip, 0)]
                        x0_pred = (x_inner - torch.sqrt(1 - Alpha_bar[tt]) * epsilon_theta) / torch.sqrt(Alpha_bar[tt])
                        x_next = torch.sqrt(Alpha_bar_prev) * x0_pred + torch.sqrt(1 - Alpha_bar_prev) * epsilon_theta
                else:
                    raise ValueError(f"Unknown sampling_mode: {sampling_mode}")

                return x_next
            x = checkpoint(net_step, x, use_reentrant=False)
            #x = net_step(x)

        return x
    
    def _sssd_prediction_step(self) -> torch.Tensor:
        self.logger.info(f"Start SSSD prediction step")

        all_batches = []
        for (batch,) in tqdm(self.dataloader, desc=f"{self.n_iter}-th predicting TS"):
            mask = self._update_mask(batch)
            batch = batch.permute(0, 2, 1)

            batch_generated = self._sampling_differentiable(
                net=self.net,
                size=batch.shape,
                diffusion_hyperparams=self.diffusion_hyperparams,
                cond=batch,
                mask=mask,
                only_generate_missing=self.only_generate_missing,
                device=self.device,
                sampling_mode=getattr(self, "sampling_mode", "ddim"),
                step_skip=getattr(self, "step_skip", 10)
            )
            all_batches.append(batch_generated)

        sssd_prediction = torch.cat(all_batches, dim=0).permute(1, 2, 0)

        return sssd_prediction

    def _autoFRK_step(self,
                      sssd_prediction,
                      ) -> torch.Tensor:
        # autoFRK
        self.logger.info(f"Start autoFRK inference step")
        dtype = sssd_prediction.dtype
        device = sssd_prediction.device
        
        V, T, N = sssd_prediction.shape
        autoFRK_result = torch.empty_like(sssd_prediction, dtype=dtype, device=device)
        frk_step = self.frk_step

        # mrts
        if self.AFRK_mrts is None:
            data_slice = sssd_prediction[0, :, :].T
            _, mrts = frk_step(data_slice, None)
            self.AFRK_mrts = mrts
        else:
            mrts = self.AFRK_mrts

        # Flattened loop over variables
        for idx in tqdm(range(V), desc=f"{self.n_iter}-th predicting autoFRK"):
            data_slice = sssd_prediction[idx, :, :].T
            try:
                pred, _ = frk_step(data_slice, mrts)
            except Exception as e:
                self.logger.warning(f"Variable {idx} numerical issue: {e}; using fallback clone")
                # data_slice is (N, T); transpose back to (T, N) to match autoFRK_result[idx]
                pred = data_slice.T.clone()
            autoFRK_result[idx, :, :] = pred
        nan_idx = torch.isnan(autoFRK_result)
        if nan_idx.any():
            self.logger.info(f"Replacing total {nan_idx.sum().item()} NaNs in autoFRK_result with original sssd_prediction values")
            autoFRK_result = torch.where(nan_idx, sssd_prediction, autoFRK_result)
        autoFRK_result = autoFRK_result.permute(2, 1, 0)

        # return
        if autoFRK_result.shape != self.real_data_shape:
            error_msg = f"Shape mismatch: autoFRK_result {autoFRK_result.shape} != real_data {self.real_data_shape}"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        return autoFRK_result

    def _train_per_epoch(self) -> torch.Tensor:
        loss_function=nn.MSELoss()
        #diffusion_loss = torch.tensor(0.0, device=self.device, dtype=torch.float32)
        #n_batches = 0

        # SSSD training
        for batch_idx, (batch,) in enumerate(tqdm(self.dataloader, desc=f"{self.n_iter}-th training TS")):
            batch = batch.to(self.device)

            # New
            mask = self._update_mask(batch)
            loss_mask = ~mask.bool()
            self.optimizer.zero_grad()

            if loss_mask.sum() == 0:
                self.logger.warning(f"Batch {batch_idx} has no valid elements for loss")
                continue

            batch = batch.permute(0, 2, 1)
            assert batch.size() == mask.size() == loss_mask.size()
            
            loss = training_loss(
                model=self.net,
                loss_function=loss_function,
                training_data=(batch, batch, mask, loss_mask),
                diffusion_parameters=self.diffusion_hyperparams,
                generate_only_missing=self.only_generate_missing,
                device=self.device,
            )

            # Each batch loss owns its graph, and the spatial loss below builds a
            # fresh graph, so nothing needs to be retained after this backward pass.
            loss.backward()
            self.optimizer.step()

            #diffusion_loss = diffusion_loss + loss
            #n_batches += 1

        if self.enable_spatial_prediction and self.n_iter % self.autoFRK_period == 0:
            self.logger.info(f"Iteration {self.n_iter}: Start Spatial Prediction step")
            sssd_prediction = self._sssd_prediction_step()
            autoFRK_result = self._autoFRK_step(sssd_prediction=sssd_prediction)

            # compute loss
            self.optimizer.zero_grad()
            spatial_loss = loss_function(
                autoFRK_result,
                self.real_data
            )
            #loss = spatial_loss + diffusion_loss / n_batches
            loss = spatial_loss

            # update model
            loss.backward()
            self.optimizer.step()
            
        self.logger.info(f"Iteration {self.n_iter}: loss: {loss.item()}")
        
        return loss

    def train(self) -> None:
        self._load_checkpoint()

        n_iter_start = (
            self.ckpt_iter + 2 if self.ckpt_iter == -1 else self.ckpt_iter + 1
        )
        self.logger.info(f"Start the {n_iter_start} iteration")

        for n_iter in range(n_iter_start, self.n_iters + 1):
            self.n_iter = n_iter
            loss = self._train_per_epoch()
            self.writer.add_scalar("Train/Loss", loss.item(), n_iter)
            if n_iter % self.iters_per_logging == 0:
                self.logger.info(f"Iteration: {n_iter} \tLoss: { loss.item()}")
            self._save_model(n_iter)
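
The step-skipping logic in `_sampling_differentiable` is easier to check in isolation. The sketch below copies only the schedule construction into a standalone helper; the name `build_timesteps` is hypothetical and not part of the trainer.

```python
from typing import List, Optional


def build_timesteps(T: int, step_skip: Optional[int] = None) -> List[int]:
    """Mirror the timestep schedule built inside `_sampling_differentiable`."""
    if step_skip is None:
        if T <= 200:
            step_skip = 1
        elif T <= 1000:
            step_skip = max(1, T // 200)
        else:
            step_skip = max(1, T // 400)
    step_skip = max(1, step_skip)

    # Walk from T - 1 down to 0 in strides, and always finish on step 0.
    timesteps = list(range(T - 1, -1, -step_skip))
    if timesteps[-1] != 0:
        timesteps.append(0)
    return timesteps


ts = build_timesteps(T=200, step_skip=10)
print(len(ts), ts[:3], ts[-2:])  # 21 [199, 189, 179] [9, 0]
```

As far as the listing shows, `_sssd_prediction_step` never sets `sampling_mode` or `step_skip` on the trainer, so the `getattr` fallbacks ("ddim" and 10) apply. With `T: 200` from model.yaml, that means 21 network evaluations per sample instead of 200, which keeps the checkpointed backward pass affordable.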

utils.py (unchanged)

from typing import Dict, Tuple

import torch

from sssd.utils.utils import std_normal
from sssd.utils.logger import setup_logger

LOGGER = setup_logger()

def training_loss(
    model: torch.nn.Module,
    loss_function: torch.nn.Module,
    training_data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    diffusion_parameters: Dict[str, torch.Tensor],
    generate_only_missing: int = 1,
    device: str = "cpu",
) -> torch.Tensor:
    """
    Compute the training loss of epsilon and epsilon_theta.

    Args:
        model (torch.nn.Module): The neural network model.
        loss_function (torch.nn.Module): The loss function, e.g. nn.MSELoss().
        training_data (tuple): Training data tuple containing (time_series, condition, mask, loss_mask).
        diffusion_parameters (dict): Dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams.
                                     Note, the tensors need to be on `device`.
        generate_only_missing (int): Flag to indicate whether to only generate missing values (default=1).
        device (str): Device to run the computations on (default="cpu").

    Returns:
        torch.Tensor: Training loss.
    """

    # Unpack diffusion hyperparameters
    T, alpha_bar = diffusion_parameters["T"], diffusion_parameters["Alpha_bar"]

    # Unpack training data
    time_series, condition, mask, loss_mask = training_data

    batch_size = time_series.shape[0]

    # Sample random diffusion steps for each batch element
    diffusion_steps = torch.randint(T, size=(batch_size, 1, 1)).to(device)
    if torch.isnan(diffusion_steps).any():
        LOGGER.warning("diffusion_steps contains NaN")

    # Generate Gaussian noise, applying mask if specified
    noise = (
        time_series * mask.float()
        + std_normal(time_series.shape, device) * (1 - mask).float()
        if generate_only_missing
        else std_normal(time_series.shape, device)
    )
    if torch.isnan(noise).any():
        LOGGER.warning("noise contains NaN")
        LOGGER.info(f"noise stats: min={noise.min().item()}, max={noise.max().item()}, mean={noise.mean().item()}")

    # Compute x_t from q(x_t|x_0)
    transformed_series = (
        torch.sqrt(alpha_bar[diffusion_steps]) * time_series
        + torch.sqrt(1 - alpha_bar[diffusion_steps]) * noise
    )
    if torch.isnan(transformed_series).any():
        LOGGER.warning("transformed_series contains NaN")
        LOGGER.info(f"transformed_series stats: min={transformed_series.min().item()}, max={transformed_series.max().item()}, mean={transformed_series.mean().item()}")

    # Predict epsilon according to epsilon_theta
    epsilon_theta = model(
        (transformed_series, condition, mask, diffusion_steps.view(batch_size, 1))
    )
    if torch.isnan(epsilon_theta).any():
        LOGGER.warning("epsilon_theta contains NaN")
        LOGGER.info(f"epsilon_theta stats: min={epsilon_theta.min().item()}, max={epsilon_theta.max().item()}, mean={epsilon_theta.mean().item()}")

    # # Compute loss
    # if generate_only_missing:
    #     return loss_function(epsilon_theta[loss_mask], noise[loss_mask])
    # else:
    #     return loss_function(epsilon_theta, noise)
    # Compute loss
    if generate_only_missing:
        loss = loss_function(epsilon_theta[loss_mask], noise[loss_mask])
    else:
        loss = loss_function(epsilon_theta, noise)

    if torch.isnan(loss):
        LOGGER.warning("loss contains NaN")
        LOGGER.info(f"loss value: {loss.item()}")

    return loss
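
`training_loss` forms the noised series from q(x_t | x_0), and the DDIM branch in trainer.py inverts the same relation to get its `x0_pred`. The scalar sketch below illustrates that the two formulas are exact inverses when the true noise is known. The helper names `q_sample` and `predict_x0` and the numeric values are illustrative assumptions, not part of the codebase.

```python
import math


def q_sample(x0: float, eps: float, alpha_bar_t: float) -> float:
    """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    return math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps


def predict_x0(x_t: float, eps: float, alpha_bar_t: float) -> float:
    """Invert the forward process for a given noise value."""
    return (x_t - math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_bar_t)


# Illustrative values only.
x0, eps, alpha_bar_t = 1.5, -0.3, 0.6
x_t = q_sample(x0, eps, alpha_bar_t)
recovered = predict_x0(x_t, eps, alpha_bar_t)
print(abs(recovered - x0) < 1e-12)  # True
```

During sampling the true noise is unavailable, so the network output `epsilon_theta` takes its place and the recovered value is only an estimate of x_0.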

inference

All files below are located in the sssd/inference/ directory.

generator.py

import logging
import os
from typing import Dict, Iterable, Optional, Union

import numpy as np
import torch
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
from torch.utils.data import DataLoader
from tqdm import tqdm  # New

from sssd.core.model_specs import MASK_FN
from sssd.data.utils import get_dataloader  # New
from sssd.utils.logger import setup_logger
from sssd.utils.utils import find_max_epoch, sampling
from autoFRK import AutoFRK, to_tensor, garbage_cleaner  # New

LOGGER = setup_logger()


class DiffusionGenerator:
    """
    Generate data based on ground truth.

    Args:
        net (torch.nn.Module): The neural network model.
        device (Optional[Union[torch.device, str]]): The device to run the model on (e.g., 'cuda' or 'cpu').
        diffusion_hyperparams (dict): Dictionary of diffusion hyperparameters.
        local_path (str): Local path format for the model.
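        data_path (str): Path to the testing data; the dataloader is built from it.
        output_directory (str): Path to save generated samples.
        batch_size (int): Number of samples to generate.
        ckpt_path (str): Checkpoint directory.
        ckpt_iter (str): Pretrained checkpoint to load; 'max' selects the maximum iteration.
        masking (str): Type of masking: 'mnr' (missing not at random), 'bm' (black-out), 'rm' (random missing).
        missing_k (int): Number of missing time points for each channel across the length.
        only_generate_missing (int): Whether to generate only missing portions of the signal:
                                      - 0 (all sample diffusion),
                                      - 1 (generate missing portions only).
        enable_spatial_prediction (bool): Whether to run the autoFRK spatial prediction step.
        known_location_path (str): Path to the .npy file of known station locations.
        unknown_location_path (str): Path to the .npy file of unknown station locations to predict at.
        AFRK_method (str): Value passed to AutoFRK.forward as `method`.
        AFRK_tps_method (str): Value passed to AutoFRK.forward as `tps_method`.
        saved_data_names (Iterable[str], optional): Names of data arrays to save (default is ("imputation", "original", "mask")).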
        logger (Optional[logging.Logger], optional): Logger object for logging messages (default is None).
    """

    def __init__(
        self,
        net: torch.nn.Module,
        device: Optional[Union[torch.device, str]],
        diffusion_hyperparams: dict,
        local_path: str,
        data_path: str,
        output_directory: str,
        batch_size: int,
        ckpt_path: str,
        ckpt_iter: str,
        masking: str,
        missing_k: int,
        only_generate_missing: int,
        enable_spatial_prediction: bool,  # New
        known_location_path: str,  # New
        unknown_location_path: str,  # New
        AFRK_method: str,  # New
        AFRK_tps_method: str,  # New
        saved_data_names: Iterable[str] = ("imputation", "original", "mask"),
        logger: Optional[logging.Logger] = None,
    ):
        self.net = net
        self.device = device
        self.diffusion_hyperparams = diffusion_hyperparams
        self.local_path = local_path
        loader, ts_mean, ts_std = get_dataloader(
            path=data_path,
            batch_size=batch_size,
            device=device,
            inference=True,
            missing_k=missing_k
        )
        self.dataloader = loader
        self.ts_mean = ts_mean
        self.ts_std = ts_std
        self.batch_size = batch_size
        self.masking = masking
        self.missing_k = missing_k
        self.only_generate_missing = only_generate_missing
        self.enable_spatial_prediction = enable_spatial_prediction  # New
        self.known_location_path = known_location_path  # New
        self.unknown_location_path = unknown_location_path  # New
        self.AFRK_method = AFRK_method  # New
        self.AFRK_tps_method = AFRK_tps_method # New
        self.logger = logger or LOGGER

        self.output_directory = self._prepare_output_directory(
            output_directory, local_path, ckpt_iter
        )
        self.saved_data_names = saved_data_names
        self._load_checkpoint(ckpt_path, ckpt_iter)

    def _load_checkpoint(self, ckpt_path: str, ckpt_iter: str) -> None:
        """Load a checkpoint for the given neural network model."""
        ckpt_path = os.path.join(ckpt_path, self.local_path)
        if ckpt_iter == "max":
            ckpt_iter = find_max_epoch(ckpt_path)
        model_path = os.path.join(ckpt_path, f"{ckpt_iter}.pkl")
        try:
            checkpoint = torch.load(model_path, map_location="cpu")
            self.net.load_state_dict(checkpoint["model_state_dict"])
            self.logger.info(f"Successfully loaded model at iteration {ckpt_iter}")
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Model file not found at {model_path}") from e
        except Exception as e:
            raise Exception(f"Failed to load model: {e}")

    def _prepare_output_directory(
        self, output_directory: str, local_path: str, ckpt_iter: str
    ) -> str:
        """Prepare the output directory to save generated samples."""
        ckpt_iter_str = (
            "max"
            if ckpt_iter == "max"
            else f"imputation_multiple_{int(ckpt_iter) // 1000}k"
        )
        output_directory = os.path.join(output_directory, local_path, ckpt_iter_str)
        os.makedirs(output_directory, exist_ok=True)
        os.chmod(output_directory, 0o775)
        self.logger.info(f"Output directory: {output_directory}")
        return output_directory

    def _update_mask(self, batch: torch.Tensor) -> torch.Tensor:
        """Update mask based on the given batch."""
        transposed_mask = MASK_FN[self.masking](batch[0], self.missing_k)
        return (
            transposed_mask.permute(1, 0)
            .repeat(batch.size()[0], 1, 1)
            .to(self.device, dtype=torch.float32)
        )

    def _save_data(
        self,
        results: Dict[str, np.ndarray],
        index: int,
    ) -> None:
        """Save generated_series, batch, and mask data arrays."""

        for name, data in results.items():
            if name in self.saved_data_names:
                filename = f"{name}{index}.npy"
                np.save(os.path.join(self.output_directory, filename), data)

    def _autoFRK_generate(
        self,
        sssd_inference,
        with_known_loc: bool = True
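    ) -> torch.Tensor:
        """Predict values at unobserved locations with autoFRK.

        Args:
            sssd_inference: SSSD output with shape (N, T, V), i.e.
                (known locations, time steps, variables).
            with_known_loc: If True, prepend the SSSD output so the result
                covers both known and unknown locations.

        Returns:
            Tensor of shape (N + M, T, V) if with_known_loc else (M, T, V),
            where M is the number of unknown locations.
        """
        LOGGER.info("Start autoFRK inference step")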
        dtype = torch.float32
        device = sssd_inference.device
        sssd_inference = to_tensor(sssd_inference, dtype = dtype, device = device)

        N, T, V = sssd_inference.shape

        known_loc = to_tensor(np.load(self.known_location_path), dtype=dtype, device=device)
        unknown_loc = to_tensor(np.load(self.unknown_location_path), dtype=dtype, device=device)
        if known_loc.ndim == 1:
            known_loc = known_loc.view(-1, 1)
        if unknown_loc.ndim == 1:
            unknown_loc = unknown_loc.view(-1, 1)
        missing_loc = unknown_loc.shape[0]
        autoFRK_inference = torch.zeros((missing_loc, T, V), dtype=dtype, device=device)
        frk = AutoFRK(
            logger_level=30,
            dtype=dtype,
            device=device,
        )

        # The basis functions (G) from the first successful fit are reused for later variables.
        mrts = None
        for variable in tqdm(range(V), desc="inferencing autoFRK"):
            data_slice = sssd_inference[:, :, variable]
            try:
                _ = frk.forward(
                    data=data_slice,
                    loc=known_loc,
                    G = mrts,
                    method=self.AFRK_method,
                    tps_method=self.AFRK_tps_method,
                    requires_grad=False
                )
                pred = frk.predict(
                    newloc = unknown_loc
                )['pred.value']
                mrts = frk.obj['G'] if mrts is None else mrts

            except torch.linalg.LinAlgError:
                # Fall back to zeros for this variable when the fit is numerically singular.
                LOGGER.warning(f"Skipped variable={variable} due to ill-conditioned matrix")
                pred = torch.zeros((missing_loc, T), dtype=dtype, device=device)
            
            autoFRK_inference[:, :, variable] = pred

        if with_known_loc:
            autoFRK_inference = torch.cat([sssd_inference, autoFRK_inference], dim=0)

        return autoFRK_inference

    def generate(self) -> None:
        """Generate samples with the SSSD model, optionally extend them to
        unknown locations with autoFRK, and save the result to disk."""
        all_generated = []
        for index, (batch,) in tqdm(enumerate(self.dataloader), total=len(self.dataloader), desc="inferencing sssd"):
            batch = batch.to(self.device)
            mask = self._update_mask(batch)
            batch = batch.permute(0, 2, 1)

            generated_series = (
                sampling(
                    net=self.net,
                    size=batch.shape,
                    diffusion_hyperparams=self.diffusion_hyperparams,
                    cond=batch,
                    mask=mask,
                    only_generate_missing=self.only_generate_missing,
                    device=self.device,
                )
            )
            all_generated.append(generated_series)
        sssd_inference = torch.cat(all_generated, dim=0).permute(0, 2, 1)
        ts_mean = to_tensor(self.ts_mean, dtype = torch.float32, device = sssd_inference.device)
        ts_std = to_tensor(self.ts_std, dtype = torch.float32, device = sssd_inference.device)
        sssd_inference = sssd_inference * ts_std + ts_mean

        if self.enable_spatial_prediction:
            autoFRK_inference = self._autoFRK_generate(
                sssd_inference = sssd_inference,
                with_known_loc = True
            )
            results = {'imputation': autoFRK_inference.detach().cpu().numpy()}
        else:
            results = {'imputation': sssd_inference.detach().cpu().numpy()}
        self._save_data(results, 0)

utils.py (unchanged)

import os

import numpy as np


def read_multiple_imputations(folder_path: str, missing_k: int) -> np.ndarray:
    """
    Read multiple imputations generated from 'inference_multiples.py'.

    Args:
        folder_path (str): The folder containing the imputation files.
        missing_k (int): The number of the last elements to be predicted.

    Returns:
        np.ndarray: An array containing imputations with shape (num_files, obs, channel, missing_k).
    """
    # Check if the folder exists
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"Folder '{folder_path}' does not exist.")

    # Get a list of files in the folder
    file_list = os.listdir(folder_path)

    # Keep only files whose names end with imputation0.npy
    npy_files = [file for file in file_list if file.endswith("imputation0.npy")]

    if not npy_files:
        raise FileNotFoundError(f"No imputation0.npy files found in '{folder_path}'.")

    # Initialize stack array
    stack_array_data = []

    # Loop through all imputation0.npy files and read them
    for npy_file in npy_files:
        # shape = (obs, channel, length) -> (1, obs, channel, length)
        array_data = read_missing_k_data(folder_path, npy_file, missing_k)
        if array_data is not None:
            # Add a new axis for stacking
            array_data = np.expand_dims(array_data, axis=0)
            stack_array_data.append(array_data)

    if not stack_array_data:
        raise ValueError("No valid data found in the imputation files.")

    # Stack the arrays vertically
    stack_array_data = np.vstack(stack_array_data)
    return stack_array_data


def read_missing_k_data(folder_path: str, npy_file: str, missing_k: int) -> np.ndarray:
    """
    Read the last 'missing_k' elements of each observation from a NumPy file.

    Args:
        folder_path (str): The folder containing the file.
        npy_file (str): The file name to read.
        missing_k (int): The number of the last elements to be read.

    Returns:
        np.ndarray: An array containing the last 'missing_k' elements of each observation.
    """
    file_path = os.path.join(folder_path, npy_file)
    data = np.load(file_path)
    last_k_elements = data[:, :, (-missing_k):]
    return last_k_elements


def predict_interval(
    pred: np.ndarray, alpha: float = 0.05
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute the (1-alpha) quantile prediction interval of imputation ecdf.

    Args:
        pred (np.ndarray): All data with shape (num_imputations, obs, channel, length).
        alpha (float, optional): Significance level of the prediction interval. Defaults to 0.05.

    Returns:
        tuple[np.ndarray, np.ndarray]: Lower and upper bounds of the prediction interval with shape (obs, channel, length).
    """
    # Compute original prediction intervals
    L = np.quantile(pred, alpha / 2, axis=0)
    U = np.quantile(pred, 1 - alpha / 2, axis=0)

    return L, U


def compute_E_star(
    L: np.ndarray, U: np.ndarray, true: np.ndarray, alpha: float = 0.05
) -> np.ndarray:
    """
    Compute the (1-alpha) quantile of conformity scores, i.e., E_star.

    Args:
        L (np.ndarray): Lower bound to be adjusted with shape (obs, channel, length).
        U (np.ndarray): Upper bound to be adjusted with shape (obs, channel, length).
        true (np.ndarray): True values with shape (obs, channel, length).
        alpha (float, optional): Mis-coverage rate of conformal prediction. Defaults to 0.05.

    Returns:
        np.ndarray: E_star with shape (channel, length).
    """
    # Compute the conformity scores
    E = np.maximum(L - true, true - U)

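    # Compute the finite-sample-corrected (1-alpha) quantile of conformity scores.
    # The level is capped at 1.0 because np.quantile rejects values above 1,
    # which happens when the number of observations is small relative to alpha.
    CP_PAR = min(1.0, (1 + 1 / true.shape[0]) * (1 - alpha))
    E_star = np.quantile(E, CP_PAR, axis=0)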
    return E_star


def adjust_PI(
    L: np.ndarray, U: np.ndarray, E_star: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Adjust prediction interval using conformal prediction.

    Args:
        L (np.ndarray): Lower bound to be adjusted with shape (obs, channel, length).
        U (np.ndarray): Upper bound to be adjusted with shape (obs, channel, length).
        E_star (np.ndarray): Scores with shape (channel, length).

    Returns:
        tuple[np.ndarray, np.ndarray]: Adjusted lower and upper bound with shape (obs, channel, length).
    """
    E_star_exd = np.expand_dims(E_star, axis=0)
    adjusted_L = L - E_star_exd
    adjusted_U = U + E_star_exd
    return adjusted_L, adjusted_U


def coverage_rate(L: np.ndarray, U: np.ndarray, true: np.ndarray) -> np.ndarray:
    """
    Compute the coverage rate, which is the proportion of [L, U] containing true data.

    Args:
        L (np.ndarray): Lower bound with shape (obs, channel, length).
        U (np.ndarray): Upper bound with shape (obs, channel, length).
        true (np.ndarray): True data with shape (obs, channel, length).

    Returns:
        np.ndarray: Coverage rate with shape (channel, length).
    """
    # Use closed bounds so the check matches the interval [L, U] described above.
    coverage = np.sum(np.logical_and(true >= L, true <= U), axis=0) / true.shape[0]
    return coverage
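
The four steps in utils.py (quantile interval, conformity scores, interval adjustment, coverage) can be followed end to end on synthetic data. The block below is a minimal sketch and not part of the repository: the array sizes, the calibration/test split, the noise model, and the helper `coverage` are all assumptions made for illustration. It inlines the same formulas rather than importing utils.py so that it runs on its own.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.05

# Hypothetical sizes: 50 imputations, 200 observations, 3 channels, 4 time steps.
n_imp, n_obs, n_ch, n_len = 50, 200, 3, 4
true = rng.normal(size=(n_obs, n_ch, n_len))

# Synthetic imputations: one shared error per observation plus small sampling
# noise, so the raw quantile interval is deliberately too narrow.
shared_error = rng.normal(size=(1, n_obs, n_ch, n_len))
pred = true[None] + shared_error + 0.3 * rng.normal(size=(n_imp, n_obs, n_ch, n_len))

# Step 1 (as in predict_interval): empirical quantile interval over imputations.
L = np.quantile(pred, alpha / 2, axis=0)
U = np.quantile(pred, 1 - alpha / 2, axis=0)

# Assumed split: first half for calibration, second half for testing.
cal, test = slice(0, n_obs // 2), slice(n_obs // 2, n_obs)

# Step 2 (as in compute_E_star): conformity scores on the calibration half.
E = np.maximum(L[cal] - true[cal], true[cal] - U[cal])
n_cal = E.shape[0]
level = min(1.0, (1 + 1 / n_cal) * (1 - alpha))  # capped so np.quantile accepts it
E_star = np.quantile(E, level, axis=0)  # shape (channel, length)

# Step 3 (as in adjust_PI): widen the test intervals by E_star.
adj_L = L[test] - E_star[None]
adj_U = U[test] + E_star[None]


# Step 4 (as in coverage_rate): share of test observations inside each interval.
def coverage(lo, hi, y):
    return np.mean((y >= lo) & (y <= hi), axis=0)


raw_cov = coverage(L[test], U[test], true[test])
adj_cov = coverage(adj_L, adj_U, true[test])
print("raw coverage:", raw_cov.mean().round(3))
print("adjusted coverage:", adj_cov.mean().round(3))
```

Calibrating on one half and evaluating on the other keeps the coverage estimate from being computed on the same observations that determined `E_star`.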

Current challenge

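autoFRK has no parameters that are updated iteratively, so simply appending autoFRK after $SSSD^{S4}$ performs no differently from using the SSSD model alone. If the two still need to be combined, the FRK model has to be incorporated into SSSD's training_loss function.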
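
One possible shape for such a combined loss is sketched below. It is an illustration under stated assumptions, not the project's implementation: `spatial_predict` is a hypothetical Gaussian-kernel smoother standing in for FRK (it is not autoFRK), and `combined_loss`, `lam`, `bandwidth`, and the toy linear denoiser are invented for the example. Whether autoFRK itself can be backpropagated through during training is not confirmed here. The point is only that once the spatial step is differentiable, its error at held-out locations can contribute gradients to the denoiser.

```python
import torch
import torch.nn.functional as F


def spatial_predict(values, known_loc, unknown_loc, bandwidth=0.5):
    """Hypothetical differentiable stand-in for an FRK prediction (not autoFRK).

    values: (N_known, T), known_loc: (N_known, d), unknown_loc: (N_unknown, d).
    Returns Gaussian-kernel weighted averages with shape (N_unknown, T).
    """
    sq_dist = torch.cdist(unknown_loc, known_loc).pow(2)
    weights = torch.softmax(-sq_dist / (2 * bandwidth**2), dim=1)
    return weights @ values


def combined_loss(eps_pred, eps_true, x_t, alpha_bar,
                  known_loc, unknown_loc, target_unknown, lam=0.1):
    """Noise-prediction MSE plus a weighted spatial-prediction MSE."""
    diffusion_term = F.mse_loss(eps_pred, eps_true)
    # Standard DDPM identity: estimate the clean series from x_t and the noise.
    x0_hat = (x_t - (1 - alpha_bar) ** 0.5 * eps_pred) / alpha_bar ** 0.5
    spatial_term = F.mse_loss(
        spatial_predict(x0_hat, known_loc, unknown_loc), target_unknown
    )
    return diffusion_term + lam * spatial_term


# Toy data; every size and value below is an assumption for illustration.
torch.manual_seed(0)
n_known, n_unknown, T = 6, 3, 5
known_loc, unknown_loc = torch.rand(n_known, 2), torch.rand(n_unknown, 2)
x0 = torch.randn(n_known, T)
target_unknown = torch.randn(n_unknown, T)  # held-out truth at unknown locations
alpha_bar = 0.7
eps_true = torch.randn(n_known, T)
x_t = alpha_bar ** 0.5 * x0 + (1 - alpha_bar) ** 0.5 * eps_true

net = torch.nn.Linear(T, T)  # toy stand-in for the SSSD denoiser
loss = combined_loss(net(x_t), eps_true, x_t, alpha_bar,
                     known_loc, unknown_loc, target_unknown)
loss.backward()
print(loss.item(), net.weight.grad.abs().sum().item())
```

The weight `lam` controls how strongly the spatial error influences training. Setting it to zero recovers the plain noise-prediction loss.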

References