
1141118 meeting

This experiment focuses on optimizing the training-time performance of SSSD + autoFRK, along with another round of code refactoring. The updated code is shown below:

Main Changes

trainer.py

import logging
import os
from typing import Any, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.checkpoint import checkpoint  # New
from tqdm import tqdm
import numpy as np  # New
import csv  # New

from sssd.core.model_specs import MASK_FN
from sssd.training.utils import training_loss
from sssd.data.utils import get_dataloader
from sssd.utils.logger import setup_logger
from sssd.utils.utils import find_max_epoch, std_normal  # New
from autoFRK import AutoFRK, to_tensor, garbage_cleaner  # New
from autoFRK.utils.helper import cbrt  # New

LOGGER = setup_logger()


class DiffusionTrainer:
    """
    Train Diffusion Models

    Args:
        dataloader (DataLoader): The training dataloader.
        diffusion_hyperparams (Dict[str, Any]): Hyperparameters for the diffusion process.
        net (nn.Module): The neural network model to be trained.
        device (torch.device): The device to be used for training.
        output_directory (str): Directory to save model checkpoints.
        ckpt_iter (Union[str, int]): The checkpoint iteration to be loaded; 'max' selects the maximum iteration.
        n_iters (int): Number of iterations to train.
        iters_per_ckpt (int): Number of iterations to save checkpoint.
        iters_per_logging (int): Number of iterations to save training log and compute validation loss.
        learning_rate (float): Learning rate for training.
        only_generate_missing (int): Option to generate missing portions of the signal only.
        masking (str): Type of masking strategy: 'mnr' for Missing Not at Random, 'bm' for Blackout Missing, 'rm' for Random Missing.
        missing_k (int): K missing time steps for each feature across the sample length.
        batch_size (int): Size of each training batch.
        logger (Optional[logging.Logger]): Logger object for logging, defaults to None.
    """

    def __init__(
        self,
        data_path: str,
        diffusion_hyperparams: Dict[str, Any],
        net: nn.Module,
        device: Optional[Union[torch.device, str]],
        output_directory: str,
        ckpt_iter: Union[str, int],
        n_iters: int,
        iters_per_ckpt: int,
        iters_per_logging: int,
        learning_rate: float,
        only_generate_missing: int,
        masking: str,
        missing_k: int,
        batch_size: int,
        enable_spatial_prediction: bool,  # New
        autoFRK_period: int,  # New
        location_path: str,  # New
        AFRK_method: str,  # New
        AFRK_tps_method: str,  # New
        logger: Optional[logging.Logger] = None,
    ) -> None:
        loader, ts_mean, ts_std = get_dataloader(
            path=data_path,
            batch_size=batch_size,
            device=device,
        )
        self.dataloader = loader
        self.ts_mean = ts_mean
        self.ts_std = ts_std
        self.diffusion_hyperparams = diffusion_hyperparams
        self.net = nn.DataParallel(net).to(device)
        self.device = device
        self.output_directory = output_directory
        self.ckpt_iter = ckpt_iter
        self.n_iters = n_iters
        self.iters_per_ckpt = iters_per_ckpt
        self.iters_per_logging = iters_per_logging
        self.learning_rate = learning_rate
        self.only_generate_missing = only_generate_missing
        self.masking = masking
        self.missing_k = missing_k
        self.writer = SummaryWriter(f"{output_directory}/log")
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        self.enable_spatial_prediction = enable_spatial_prediction  # New
        self.autoFRK_period = autoFRK_period  # New
        self.location_path = location_path  # New
        self.real_data = self.dataloader.dataset.tensors[0].to(self.device) # * self.ts_std + self.ts_mean # New
        self.real_data_shape = self.real_data.shape  # New
        self.logger = logger or LOGGER
        self.AFRK_mrts = None  # New
        self.AFRK_method = AFRK_method  # New
        self.AFRK_tps_method = AFRK_tps_method  # New
        self.loc = to_tensor(np.load(self.location_path), dtype=self.real_data.dtype, device=device)

        frk = AutoFRK(
            logger_level=30,
            dtype=self.real_data.dtype,
            device=self.real_data.device,
        )
  
        # autoFRK
        def frk_step(data_slice, mrts):
            _ = frk.forward(
                data=data_slice,
                loc=self.loc,
                G=mrts,
                method=self.AFRK_method,
                tps_method=self.AFRK_tps_method,
                requires_grad=True
            )
            pred = frk.predict()['pred.value']
            return pred.T, frk.obj['G']
        #self.frk_step = torch.compile(frk_step)
        self.frk_step = frk_step

        if self.masking not in MASK_FN:
            raise KeyError(f"Please enter a correct masking, but got {self.masking}")

    def _load_checkpoint(self) -> None:
        if self.ckpt_iter == "max":
            self.ckpt_iter = find_max_epoch(self.output_directory)
        if self.ckpt_iter >= 0:
            try:
                model_path = os.path.join(
                    self.output_directory, f"{self.ckpt_iter}.pkl"
                )
                checkpoint = torch.load(model_path, map_location="cpu")

                self.net.load_state_dict(checkpoint["model_state_dict"])
                if "optimizer_state_dict" in checkpoint:
                    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

                self.logger.info(
                    f"Successfully loaded model at iteration {self.ckpt_iter}"
                )
            except Exception as e:
                self.ckpt_iter = -1
                self.logger.error(f"No valid checkpoint model found. Error: {e}")
        else:
            self.ckpt_iter = -1
            self.logger.info(
                "No valid checkpoint model found, start training from initialization."
            )

    def _save_model(self, n_iter: int) -> None:
        if n_iter > 0 and n_iter % self.iters_per_ckpt == 0:
            torch.save(
                {
                    "model_state_dict": self.net.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                os.path.join(self.output_directory, f"{n_iter}.pkl"),
            )

    def _update_mask(self, batch: torch.Tensor) -> torch.Tensor:
        transposed_mask = MASK_FN[self.masking](batch[0], self.missing_k)
        return (
            transposed_mask.permute(1, 0)
            .repeat(batch.size()[0], 1, 1)
            .to(self.device, dtype=torch.float32)
        )
  
    def _sampling_differentiable(
        self,
        net: torch.nn.Module,
        size: Tuple[int, int, int],
        diffusion_hyperparams: Dict[str, torch.Tensor],
        cond: torch.Tensor,
        mask: torch.Tensor,
        only_generate_missing: int = 0,
        device: Union[torch.device, str] = "cpu",
        sampling_mode: str = "ddpm",  # ✅ "ddpm" 或 "ddim"
        step_skip: Union[int, None] = None  # ✅ 若為 None,則自動偵測
    ) -> torch.Tensor:
        _dh = diffusion_hyperparams
        T, Alpha, Alpha_bar, Sigma = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"]
        x = torch.randn(size, device=device)
        mask = mask.to(device)
        cond = cond.to(device)

        # ✅ Auto-detect a suitable step_skip
        if step_skip is None:
            if T <= 200:
                step_skip = 1       # small model: full-step, fine-grained inference
            elif T <= 1000:
                step_skip = max(1, T // 200)  # typical model: finish within ~200 steps
            else:
                step_skip = max(1, T // 400)  # large model: keep within ~400 steps
            #LOGGER.info(f"T={T}, using step_skip={step_skip}")
        step_skip = max(1, step_skip)

        # ✅ Build the skip-step timestep schedule
        timesteps = list(range(T - 1, -1, -step_skip))
        if timesteps[-1] != 0:
            timesteps.append(0)

        noise_list = []
        if sampling_mode == "ddpm":
            for _ in timesteps:
                # pre-generate noise with same shape as x; deterministic for forward+recompute
                noise_list.append(torch.randn(size, device=device))
        else:
            # keep placeholder list (not used for ddim)
            noise_list = [None] * len(timesteps)

        for idx, t in enumerate(timesteps):
            noise_t = noise_list[idx]
            def net_step(x_inner, noise=noise_t, tt=t):
                if only_generate_missing == 1:
                    x_inner = x_inner * (1 - mask).float() + cond * mask.float()
                diffusion_steps = (tt * torch.ones((size[0], 1), device=device))
                epsilon_theta = net((x_inner, cond, mask, diffusion_steps))

                # ✅ Pick the sampling formula according to sampling_mode
                if sampling_mode == "ddpm":
                    # DDPM: stochastic variant
                    x_next = (x_inner - (1 - Alpha[tt]) / torch.sqrt(1 - Alpha_bar[tt]) * epsilon_theta) / torch.sqrt(Alpha[tt])
                    if tt > 0:
                        x_next = x_next + Sigma[tt] * noise
                elif sampling_mode == "ddim":
                    # DDIM: deterministic variant with step skipping
                    if tt == 0:
                        x_next = torch.sqrt(Alpha_bar[tt]) * (
                            (x_inner - torch.sqrt(1 - Alpha_bar[tt]) * epsilon_theta)
                            / torch.sqrt(Alpha_bar[tt])
                        )
                    else:
                        Alpha_bar_prev = Alpha_bar[max(tt - step_skip, 0)]
                        x0_pred = (x_inner - torch.sqrt(1 - Alpha_bar[tt]) * epsilon_theta) / torch.sqrt(Alpha_bar[tt])
                        x_next = torch.sqrt(Alpha_bar_prev) * x0_pred + torch.sqrt(1 - Alpha_bar_prev) * epsilon_theta
                else:
                    raise ValueError(f"Unknown sampling_mode: {sampling_mode}")

                return x_next
            x = checkpoint(net_step, x, use_reentrant=False)
            #x = net_step(x)

        return x
  
    def _sssd_prediction_step(self) -> torch.Tensor:
        self.logger.info(f"Start SSSD prediction step")

        all_batches = []
        for (batch,) in tqdm(self.dataloader, desc=f"{self.n_iter}-th predicting TS"):
            mask = self._update_mask(batch)
            batch = batch.permute(0, 2, 1)

            batch_generated = self._sampling_differentiable(
                net=self.net,
                size=batch.shape,
                diffusion_hyperparams=self.diffusion_hyperparams,
                cond=batch,
                mask=mask,
                only_generate_missing=self.only_generate_missing,
                device=self.device,
                sampling_mode=getattr(self, "sampling_mode", "ddim"),
                step_skip=getattr(self, "step_skip", 10)
            )
            all_batches.append(batch_generated)

        sssd_prediction = torch.cat(all_batches, dim=0).permute(1, 2, 0)

        return sssd_prediction

    def _autoFRK_step(self,
                      sssd_prediction,
                      ) -> torch.Tensor:
        # autoFRK
        self.logger.info(f"Start autoFRK inference step")
        dtype = sssd_prediction.dtype
        device = sssd_prediction.device
  
        V, T, N = sssd_prediction.shape
        autoFRK_result = torch.empty_like(sssd_prediction, dtype=dtype, device=device)
        frk_step = self.frk_step

        # mrts
        if self.AFRK_mrts is None:
            data_slice = sssd_prediction[0, :, :].T
            _, mrts = frk_step(data_slice, None)
            self.AFRK_mrts = mrts
        else:
            mrts = self.AFRK_mrts

        # Flattened loop over the V variables
        for idx in tqdm(range(V), desc=f"{self.n_iter}-th predicting autoFRK"):
            data_slice = sssd_prediction[idx, :, :].T
            try:
                pred, _ = frk_step(data_slice, mrts)
            except Exception as e:
                self.logger.warning(f"Variable {idx} numerical issue: {e}; using fallback clone")
                pred = data_slice.clone()
            autoFRK_result[idx, :, :] = pred
        nan_idx = torch.isnan(autoFRK_result)
        if nan_idx.any():
            self.logger.info(f"Replacing total {nan_idx.sum().item()} NaNs in autoFRK_result with original sssd_prediction values")
            autoFRK_result = torch.where(nan_idx, sssd_prediction, autoFRK_result)
        autoFRK_result = autoFRK_result.permute(2, 1, 0)

        # return
        if autoFRK_result.shape != self.real_data_shape:
            error_msg = f"Shape mismatch: autoFRK_result {autoFRK_result.shape} != real_data {self.real_data_shape}"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        return autoFRK_result

    def _train_per_epoch(self) -> torch.Tensor:
        loss_function = nn.MSELoss()
        #diffusion_loss = torch.tensor(0.0, device=self.device, dtype=torch.float32)
        #n_batches = 0

        # SSSD training
        for batch_idx, (batch,) in enumerate(tqdm(self.dataloader, desc=f"{self.n_iter}-th training TS")):
            batch = batch.to(self.device)

            # New
            mask = self._update_mask(batch)
            loss_mask = ~mask.bool()
            self.optimizer.zero_grad()

            if loss_mask.sum() == 0:
                self.logger.warning(f"Batch {batch_idx} has no valid elements for loss")
                continue

            batch = batch.permute(0, 2, 1)
            assert batch.size() == mask.size() == loss_mask.size()
  
            loss = training_loss(
                model=self.net,
                loss_function=loss_function,
                training_data=(batch, batch, mask, loss_mask),
                diffusion_parameters=self.diffusion_hyperparams,
                generate_only_missing=self.only_generate_missing,
                device=self.device,
            )

            if not (self.enable_spatial_prediction and self.n_iter % self.autoFRK_period == 0):
                loss.backward()
                self.optimizer.step()
                continue

            loss.backward(retain_graph=True)
            self.optimizer.step()
            #diffusion_loss = diffusion_loss + loss
            #n_batches += 1

        if self.enable_spatial_prediction and self.n_iter % self.autoFRK_period == 0:
            self.logger.info(f"Iteration {self.n_iter}: Start Spatial Prediction step")
            sssd_prediction = self._sssd_prediction_step()
            autoFRK_result = self._autoFRK_step(sssd_prediction=sssd_prediction)

            # compute loss
            self.optimizer.zero_grad()
            spatial_loss = loss_function(
                autoFRK_result,
                self.real_data
            )
            #loss = spatial_loss + diffusion_loss / n_batches
            loss = spatial_loss

            # update model
            loss.backward()
            self.optimizer.step()
  
        self.logger.info(f"Iteration {self.n_iter}: loss: {loss.item()}")
  
        return loss

    def train(self) -> None:
        self._load_checkpoint()

        n_iter_start = (
            self.ckpt_iter + 2 if self.ckpt_iter == -1 else self.ckpt_iter + 1
        )
        self.logger.info(f"Start the {n_iter_start} iteration")

        for n_iter in range(n_iter_start, self.n_iters + 1):
            self.n_iter = n_iter
            loss = self._train_per_epoch()
            self.writer.add_scalar("Train/Loss", loss.item(), n_iter)
            if n_iter % self.iters_per_logging == 0:
                self.logger.info(f"Iteration: {n_iter} \tLoss: { loss.item()}")
            self._save_model(n_iter)

generator.py

import logging
import os
from typing import Dict, Iterable, Optional, Union

import numpy as np
import torch
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
from torch.utils.data import DataLoader
from tqdm import tqdm  # New

from sssd.core.model_specs import MASK_FN
from sssd.data.utils import get_dataloader  # New
from sssd.utils.logger import setup_logger
from sssd.utils.utils import find_max_epoch, sampling
from autoFRK import AutoFRK, to_tensor, garbage_cleaner  # New

LOGGER = setup_logger()


class DiffusionGenerator:
    """
    Generate data based on ground truth.

    Args:
        net (torch.nn.Module): The neural network model.
        device (Optional[Union[torch.device, str]]): The device to run the model on (e.g., 'cuda' or 'cpu').
        diffusion_hyperparams (dict): Dictionary of diffusion hyperparameters.
        local_path (str): Local path format for the model.
        testing_data (torch.Tensor): Tensor containing testing data.
        output_directory (str): Path to save generated samples.
        batch_size (int): Number of samples to generate.
        ckpt_path (str): Checkpoint directory.
        ckpt_iter (str): Pretrained checkpoint to load; 'max' selects the maximum iteration.
        masking (str): Type of masking: 'mnr' (missing not at random), 'bm' (black-out), 'rm' (random missing).
        missing_k (int): Number of missing time points for each channel across the length.
        only_generate_missing (int): Whether to generate only missing portions of the signal:
                                      - 0 (all sample diffusion),
                                      - 1 (generate missing portions only).
        saved_data_names (Iterable[str], optional): Names of data arrays to save (default is ("imputation", "original", "mask")).
        logger (Optional[logging.Logger], optional): Logger object for logging messages (default is None).
    """

    def __init__(
        self,
        net: torch.nn.Module,
        device: Optional[Union[torch.device, str]],
        diffusion_hyperparams: dict,
        local_path: str,
        data_path: str,
        output_directory: str,
        batch_size: int,
        ckpt_path: str,
        ckpt_iter: str,
        masking: str,
        missing_k: int,
        only_generate_missing: int,
        enable_spatial_prediction: bool,  # New
        known_location_path: str,  # New
        unknown_location_path: str,  # New
        AFRK_method: str,  # New
        AFRK_tps_method: str,  # New
        saved_data_names: Iterable[str] = ("imputation", "original", "mask"),
        logger: Optional[logging.Logger] = None,
    ):
        self.net = net
        self.device = device
        self.diffusion_hyperparams = diffusion_hyperparams
        self.local_path = local_path
        loader, ts_mean, ts_std = get_dataloader(
            path=data_path,
            batch_size=batch_size,
            device=device,
            inference=True,
            missing_k=missing_k
        )
        self.dataloader = loader
        self.ts_mean = ts_mean
        self.ts_std = ts_std
        self.batch_size = batch_size
        self.masking = masking
        self.missing_k = missing_k
        self.only_generate_missing = only_generate_missing
        self.enable_spatial_prediction = enable_spatial_prediction  # New
        self.known_location_path = known_location_path  # New
        self.unknown_location_path = unknown_location_path  # New
        self.missing_k = missing_k  # New
        self.AFRK_method = AFRK_method  # New
        self.AFRK_tps_method = AFRK_tps_method # New
        self.logger = logger or LOGGER

        self.output_directory = self._prepare_output_directory(
            output_directory, local_path, ckpt_iter
        )
        self.saved_data_names = saved_data_names
        self._load_checkpoint(ckpt_path, ckpt_iter)

    def _load_checkpoint(self, ckpt_path: str, ckpt_iter: str) -> None:
        """Load a checkpoint for the given neural network model."""
        ckpt_path = os.path.join(ckpt_path, self.local_path)
        if ckpt_iter == "max":
            ckpt_iter = find_max_epoch(ckpt_path)
        model_path = os.path.join(ckpt_path, f"{ckpt_iter}.pkl")
        try:
            checkpoint = torch.load(model_path, map_location="cpu")
            self.net.load_state_dict(checkpoint["model_state_dict"])
            self.logger.info(f"Successfully loaded model at iteration {ckpt_iter}")
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Model file not found at {model_path}") from e
        except Exception as e:
            raise Exception(f"Failed to load model: {e}")

    def _prepare_output_directory(
        self, output_directory: str, local_path: str, ckpt_iter: str
    ) -> str:
        """Prepare the output directory to save generated samples."""
        ckpt_iter_str = (
            "max"
            if ckpt_iter == "max"
            else f"imputation_multiple_{int(ckpt_iter) // 1000}k"
        )
        output_directory = os.path.join(output_directory, local_path, ckpt_iter_str)
        os.makedirs(output_directory, exist_ok=True)
        os.chmod(output_directory, 0o775)
        self.logger.info(f"Output directory: {output_directory}")
        return output_directory

    def _update_mask(self, batch: torch.Tensor) -> torch.Tensor:
        """Update mask based on the given batch."""
        transposed_mask = MASK_FN[self.masking](batch[0], self.missing_k)
        return (
            transposed_mask.permute(1, 0)
            .repeat(batch.size()[0], 1, 1)
            .to(self.device, dtype=torch.float32)
        )

    def _save_data(
        self,
        results: Dict[str, np.ndarray],
        index: int,
    ) -> None:
        """Save generated_series, batch, and mask data arrays."""

        for name, data in results.items():
            if name in self.saved_data_names:
                filename = f"{name}{index}.npy"
                np.save(os.path.join(self.output_directory, filename), data)

    def _autoFRK_generate(
        self,
        sssd_inference,
        with_known_loc: bool = True
    ) -> torch.Tensor:
        LOGGER.info(f"Start autoFRK inference step")
        dtype = torch.float32
        device = sssd_inference.device
        sssd_inference = to_tensor(sssd_inference, dtype=dtype, device=device)

        N, T, V = sssd_inference.shape

        known_loc = to_tensor(np.load(self.known_location_path), dtype=dtype, device=device)
        unknown_loc = to_tensor(np.load(self.unknown_location_path), dtype=dtype, device=device)
        if known_loc.ndim == 1:
            known_loc = known_loc.view(-1, 1)
        if unknown_loc.ndim == 1:
            unknown_loc = unknown_loc.view(-1, 1)
        missing_loc = unknown_loc.shape[0]
        autoFRK_inference = torch.zeros((missing_loc, T, V), dtype=dtype, device=device)
        print(f"missing_loc : {missing_loc}")
        print(f"autoFRK_inference : {autoFRK_inference.shape}")
        frk = AutoFRK(
            logger_level=30,
            dtype=dtype,
            device=device,
        )

        mrts = None
        for variable in tqdm(range(V), desc=f"inferencing autoFRK"):
            data_slice = sssd_inference[:, :, variable]
            try:
                _ = frk.forward(
                    data=data_slice,
                    loc=known_loc,
                    G=mrts,
                    method=self.AFRK_method,
                    tps_method=self.AFRK_tps_method,
                    requires_grad=False
                )
                pred = frk.predict(
                    newloc=unknown_loc
                )['pred.value']
                mrts = frk.obj['G'] if mrts is None else mrts

            except torch._C._LinAlgError:
                LOGGER.warning(f"Skipped variable={variable} due to ill-conditioned matrix")
                pred = torch.zeros((missing_loc, T), dtype=dtype, device=device)
            
            print(f"pred : {pred.shape}")
            print(f"autoFRK_inference[:, :, variable]  : {autoFRK_inference[:, :, variable] .shape}")
            autoFRK_inference[:, :, variable] = pred

        if with_known_loc:
            autoFRK_inference = torch.cat([sssd_inference, autoFRK_inference], dim=0)

        return autoFRK_inference

    def generate(self) -> None:
        """Generate samples using the given neural network model."""
        all_generated = []
        for index, (batch,) in tqdm(enumerate(self.dataloader), total=len(self.dataloader), desc="inferencing sssd"):
            batch = batch.to(self.device)
            mask = self._update_mask(batch)
            batch = batch.permute(0, 2, 1)

            generated_series = (
                sampling(
                    net=self.net,
                    size=batch.shape,
                    diffusion_hyperparams=self.diffusion_hyperparams,
                    cond=batch,
                    mask=mask,
                    only_generate_missing=self.only_generate_missing,
                    device=self.device,
                )
            )
            all_generated.append(generated_series)
        sssd_inference = torch.cat(all_generated, dim=0).permute(0, 2, 1)
        ts_mean = to_tensor(self.ts_mean, dtype=torch.float32, device=sssd_inference.device)
        ts_std = to_tensor(self.ts_std, dtype=torch.float32, device=sssd_inference.device)
        sssd_inference = sssd_inference * ts_std + ts_mean

        if self.enable_spatial_prediction:
            autoFRK_inference = self._autoFRK_generate(
                sssd_inference=sssd_inference,
                with_known_loc=True
            )
            results = {'imputation': autoFRK_inference.detach().cpu().numpy()}
        else:
            results = {'imputation': sssd_inference.detach().cpu().numpy()}
        self._save_data(results, 0)
Evaluation results for the four experimental configurations:

  • weather2k-fast-rectangular

    | Metric | All Locs & All Time | Known Locs & All Time | Unknown Locs & All Time | All Locs & Future | Known Locs & Future | Unknown Locs & Future | All Locs & Past | Known Locs & Past | Unknown Locs & Past |
    |---|---|---|---|---|---|---|---|---|---|
    | MSPE | 6.824578e+00 | 6.900100e+00 | 6.523295e+00 | 6.818313e+00 | 6.641201e+00 | 7.524866e+00 | 6.826123e+00 | 6.963984e+00 | 6.276154e+00 |
    | RMSPE | 2.612389e+00 | 2.626804e+00 | 2.554074e+00 | 2.611190e+00 | 2.577053e+00 | 2.743149e+00 | 2.612685e+00 | 2.638936e+00 | 2.505225e+00 |
    | MSPE% | 1.334338e+10 | 1.351421e+10 | 1.266188e+10 | 2.100891e+10 | 2.423505e+10 | 8.138878e+09 | 1.145188e+10 | 1.086881e+10 | 1.377794e+10 |
    | RMSPE% | 1.155135e+05 | 1.162506e+05 | 1.125250e+05 | 1.449445e+05 | 1.556761e+05 | 9.021573e+04 | 1.070135e+05 | 1.042536e+05 | 1.173795e+05 |
    | MAPE | 1.851068e+00 | 1.860901e+00 | 1.811842e+00 | 1.853678e+00 | 1.828216e+00 | 1.955250e+00 | 1.850424e+00 | 1.868966e+00 | 1.776456e+00 |
    | MAPE% | 4.196054e+09 | 3.992812e+09 | 5.006849e+09 | 5.329069e+09 | 5.366773e+09 | 5.178660e+09 | 3.916479e+09 | 3.653783e+09 | 4.964454e+09 |
  • weather2k-fast-spherical_fast

    | Metric | All Locs & All Time | Known Locs & All Time | Unknown Locs & All Time | All Locs & Future | Known Locs & Future | Unknown Locs & Future | All Locs & Past | Known Locs & Past | Unknown Locs & Past |
    |---|---|---|---|---|---|---|---|---|---|
    | MSPE | 6.926058e+00 | 7.029210e+00 | 6.514549e+00 | 6.847387e+00 | 6.721686e+00 | 7.348849e+00 | 6.945470e+00 | 7.105093e+00 | 6.308683e+00 |
    | RMSPE | 2.631740e+00 | 2.651266e+00 | 2.552362e+00 | 2.616751e+00 | 2.592621e+00 | 2.710876e+00 | 2.635426e+00 | 2.665538e+00 | 2.511709e+00 |
    | MSPE% | 1.298222e+10 | 1.329351e+10 | 1.174043e+10 | 1.795051e+10 | 2.071103e+10 | 6.937951e+09 | 1.175629e+10 | 1.146321e+10 | 1.292546e+10 |
    | RMSPE% | 1.139396e+05 | 1.152975e+05 | 1.083533e+05 | 1.339795e+05 | 1.439133e+05 | 8.329436e+04 | 1.084264e+05 | 1.070664e+05 | 1.136902e+05 |
    | MAPE | 1.869697e+00 | 1.884203e+00 | 1.811828e+00 | 1.865301e+00 | 1.847785e+00 | 1.935177e+00 | 1.870782e+00 | 1.893189e+00 | 1.781392e+00 |
    | MAPE% | 4.007189e+09 | 3.822311e+09 | 4.744721e+09 | 4.871713e+09 | 4.917260e+09 | 4.690014e+09 | 3.793864e+09 | 3.552129e+09 | 4.758220e+09 |
  • weather2k-sssds4-fast-rectangular

    | Metric | All Locs & All Time | Known Locs & All Time | Unknown Locs & All Time | All Locs & Future | Known Locs & Future | Unknown Locs & Future | All Locs & Past | Known Locs & Past | Unknown Locs & Past |
    |---|---|---|---|---|---|---|---|---|---|
    | MSPE | 6.854971e+00 | 6.985919e+00 | 6.332583e+00 | 6.770085e+00 | 6.654324e+00 | 7.231891e+00 | 6.875918e+00 | 7.067741e+00 | 6.110676e+00 |
    | RMSPE | 2.618200e+00 | 2.643089e+00 | 2.516462e+00 | 2.601939e+00 | 2.579598e+00 | 2.689218e+00 | 2.622197e+00 | 2.658522e+00 | 2.471978e+00 |
    | MSPE% | 1.343041e+10 | 1.377710e+10 | 1.204736e+10 | 1.953017e+10 | 2.253368e+10 | 7.548256e+09 | 1.192528e+10 | 1.161639e+10 | 1.315753e+10 |
    | RMSPE% | 1.158897e+05 | 1.173759e+05 | 1.097605e+05 | 1.397504e+05 | 1.501122e+05 | 8.688070e+04 | 1.092029e+05 | 1.077794e+05 | 1.147063e+05 |
    | MAPE | 1.855692e+00 | 1.874323e+00 | 1.781368e+00 | 1.850206e+00 | 1.835460e+00 | 1.909031e+00 | 1.857046e+00 | 1.883913e+00 | 1.749867e+00 |
    | MAPE% | 4.098187e+09 | 3.900545e+09 | 4.886644e+09 | 5.146630e+09 | 5.190571e+09 | 4.971337e+09 | 3.839481e+09 | 3.582227e+09 | 4.865745e+09 |
  • weather2k-sssds4-fast-spherical_fast

    | Metric | All Locs & All Time | Known Locs & All Time | Unknown Locs & All Time | All Locs & Future | Known Locs & Future | Unknown Locs & Future | All Locs & Past | Known Locs & Past | Unknown Locs & Past |
    |---|---|---|---|---|---|---|---|---|---|
    | MSPE | 6.913181e+00 | 7.042152e+00 | 6.398676e+00 | 6.796332e+00 | 6.716533e+00 | 7.114675e+00 | 6.942014e+00 | 7.122500e+00 | 6.222000e+00 |
    | RMSPE | 2.629293e+00 | 2.653705e+00 | 2.529560e+00 | 2.606977e+00 | 2.591627e+00 | 2.667335e+00 | 2.634770e+00 | 2.668801e+00 | 2.494394e+00 |
    | MSPE% | 1.281252e+10 | 1.299316e+10 | 1.209187e+10 | 1.925186e+10 | 2.218335e+10 | 7.557259e+09 | 1.122359e+10 | 1.072546e+10 | 1.321080e+10 |
    | RMSPE% | 1.131924e+05 | 1.139876e+05 | 1.099630e+05 | 1.387511e+05 | 1.489407e+05 | 8.693250e+04 | 1.059414e+05 | 1.035638e+05 | 1.149382e+05 |
    | MAPE | 1.862306e+00 | 1.881270e+00 | 1.786649e+00 | 1.855205e+00 | 1.844276e+00 | 1.898800e+00 | 1.864058e+00 | 1.890399e+00 | 1.758975e+00 |
    | MAPE% | 4.030800e+09 | 3.818047e+09 | 4.879537e+09 | 5.100416e+09 | 5.145186e+09 | 4.921815e+09 | 3.766869e+09 | 3.490571e+09 | 4.869104e+09 |

Optimizations

Combining earlier findings with this week's experiments, the following optimizations have been made to the code.

All experiments below use the Weather2K dataset as training data, with 6 of its variables used for model training and testing: 384 time points for training and 96 for testing, of which 19 are unseen time points. In addition, 1,492 locations are treated as known and 374 as unknown.

| Optimization | Before | After | How |
|---|---|---|---|
| SSSD prediction during training | ~60 minutes | ~90 seconds | Step skipping in the reverse process: each executed step jumps ahead 10 timesteps |
| autoFRK imputation of known locations | ~70 minutes | ~10 seconds | Data slices now accept a full batch (N, T) instead of a single time point |
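The step-skipping schedule in the first row comes from the timestep construction in _sampling_differentiable above; as a quick standalone illustration (T and step_skip values chosen here only for the example), with T = 200 and step_skip = 10 the reverse process needs only 21 network evaluations instead of 200:

T, step_skip = 200, 10
timesteps = list(range(T - 1, -1, -step_skip))  # 199, 189, ..., 9
if timesteps[-1] != 0:
    timesteps.append(0)  # always finish the reverse process at t = 0
print(len(timesteps))    # 21 network evaluations instead of 200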

Currently, for SSSD to hand its output to autoFRK during training, predictions over the complete dataset must be computed first and then passed to autoFRK. Because autoFRK is dominated by large matrix operations, feeding it the complete dataset produces huge matrices: GPU VRAM cannot keep up, compute is wasted, and the backward pass that updates the model parameters stalls for a long time (about 4 to 5 minutes).

但由於目前已將 autoFRK 修復支援批次資料輸入,故可以將以上運行方式改為 SSSD 批次訓練 → SSSD 批次預測 → autoFRK 批次產生結果 → 依批次產生 loss → backward 調整梯度與模型參數。 由於是以批次運算進行,故GPU 使用率應可大幅下降,且應可加速計算時間。如改寫順利,應可於往後會議中提出。
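A minimal sketch of the proposed per-batch loop, not the current implementation: it assumes the training_loss, _sampling_differentiable, and _autoFRK_step interfaces shown in trainer.py above, that the shape check in _autoFRK_step is relaxed to per-batch shapes, and real_batch is a hypothetical ground-truth slice aligned with the current batch.

loss_function = nn.MSELoss()
for (batch,) in self.dataloader:
    batch = batch.to(self.device)
    mask = self._update_mask(batch)
    batch = batch.permute(0, 2, 1)
    self.optimizer.zero_grad()

    # SSSD: diffusion loss on this batch
    diffusion_loss = training_loss(
        model=self.net,
        loss_function=loss_function,
        training_data=(batch, batch, mask, ~mask.bool()),
        diffusion_parameters=self.diffusion_hyperparams,
        generate_only_missing=self.only_generate_missing,
        device=self.device,
    )

    # SSSD: differentiable prediction for this batch only
    pred = self._sampling_differentiable(
        net=self.net,
        size=batch.shape,
        diffusion_hyperparams=self.diffusion_hyperparams,
        cond=batch,
        mask=mask,
        only_generate_missing=self.only_generate_missing,
        device=self.device,
        sampling_mode="ddim",
        step_skip=10,
    )

    # autoFRK on the small per-batch prediction keeps the matrices small,
    # so VRAM stays bounded and backward() no longer stalls.
    frk_result = self._autoFRK_step(sssd_prediction=pred.permute(1, 2, 0))

    # Per-batch combined loss, then a single backward pass updates the model.
    loss = diffusion_loss + loss_function(frk_result, real_batch)
    loss.backward()
    self.optimizer.step()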

autoFRK Optimizations

As mentioned above, this experiment optimizes autoFRK to restore its ability to accept batched input. The original R package already had this problem, so the Python module written on top of it inherited it. After a simple fix, the MSE comparison is as follows:

| Data size | Before fix | After fix |
|---|---|---|
| (N, 1) | Extremely small values | Same as before the fix |
| (N, 2+) | Values diverge wildly, up to tens of thousands of times the (N, 1) values | Same as the (N, 1) values |

After the fix, batched input to autoFRK yields the same MSE as looping over single data slices; the only remaining tiny numerical differences are plausibly attributable to the eigenvalue computations or to floating-point error. The time cost, however, differs enormously between slice-wise and batched input: in this experiment the gap reaches 14x.

Testing also showed that under the "EM" method, batched input does not always outperform the single-slice approach. If batched input is needed, the "fast" method should therefore be preferred, or the "EM" method should be tuned (e.g., with more iterations) to reach the best imputation quality. A minimal sketch of the two input styles follows.
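The sketch below contrasts slice-wise and batched input. It follows the AutoFRK constructor and forward/predict usage shown in trainer.py above, and assumes G, tps_method, and requires_grad fall back to defaults when omitted; it is an illustration, not the shipped code.

import torch
from autoFRK import AutoFRK

frk = AutoFRK(logger_level=30, dtype=torch.float64, device="cpu")

def impute_slicewise(data: torch.Tensor, loc: torch.Tensor) -> torch.Tensor:
    # Pre-fix workaround: one (N, 1) forward/predict call per time point.
    preds = []
    for t in range(data.shape[1]):
        frk.forward(data=data[:, t:t + 1], loc=loc, method="fast")
        preds.append(frk.predict()["pred.value"])
    return torch.cat(preds, dim=1)

def impute_batched(data: torch.Tensor, loc: torch.Tensor) -> torch.Tensor:
    # Post-fix: a single (N, T) call; prefer method="fast" over "EM" here.
    frk.forward(data=data, loc=loc, method="fast")
    return frk.predict()["pred.value"]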

The change to the Python version of autoFRK is as follows:

In the cMLEimat function of /src/autoFRK/utils/estimator.py, replace

    if ncol_Fk > 2:
        reduced_columns = torch.cat([
            torch.tensor([0], dtype=torch.int64, device=device),
            (d_hat[1:(ncol_Fk - 1)] > 0).nonzero(as_tuple=True)[0]
        ])
    else:
        reduced_columns = torch.tensor([ncol_Fk - 1], dtype=torch.int64, device=device)

with

    if ncol_Fk > 2:
        reduced_columns = torch.unique(torch.cat([
            torch.tensor([0], dtype=torch.int64, device=device),
            (d_hat[1:(ncol_Fk - 1)] > 0).nonzero(as_tuple=True)[0]
        ]))
    else:
        reduced_columns = torch.tensor([ncol_Fk - 1], dtype=torch.int64, device=device)
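A standalone illustration of the duplicate this unique call removes (the d_hat values are made up): nonzero returns indices relative to the d_hat[1:] slice, so whenever that slice's first entry is positive, index 0 appears twice and the same basis column is selected twice, which is exactly the rank problem noted at the end of this section.

import torch

d_hat = torch.tensor([5.0, 3.0, 2.0, 0.0, 1.0])  # made-up estimates
ncol_Fk = d_hat.numel()

idx = (d_hat[1:(ncol_Fk - 1)] > 0).nonzero(as_tuple=True)[0]
before = torch.cat([torch.tensor([0], dtype=torch.int64), idx])
after = torch.unique(before)
print(before)  # tensor([0, 0, 1]) -- column 0 selected twice
print(after)   # tensor([0, 1])   -- duplicate removed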

The corresponding change to the R version of autoFRK is as follows:

In the cMLEimat function of /R/estimator.R, replace

    if (ncol_Fk > 2) {
      reduced_columns <- c(1, which(d_hat[2:ncol_Fk] > 0))
    } else {
      reduced_columns <- ncol_Fk
    }

with

    if (ncol_Fk > 2) {
      reduced_columns <- unique(c(1, which(d_hat[2:ncol_Fk] > 0)))
    } else {
      reduced_columns <- ncol_Fk
    }

This fix applies when the number of basis-function columns exceeds 2: only distinct columns are kept, which should avoid rank problems in subsequent computations. No other contributing factors have been found so far, and the fixes above have been recorded in the following repositories:

References