目錄

1141104 meeting

更新與瓶頸

本次更新針對 autoFRK-python 相關程式碼進行優化,以更完善地支援梯度計算,從而能與原本的 SSSD 程式碼整合。

目前在使用 GPU 進行運算時,autoFRK 模組尚不支援 batch 輸入,僅能針對單一時間點進行計算。
因此必須透過 for-loop 逐步處理,導致:

  • 計算時間大幅延長:約從 9 秒增加至 70 分鐘
  • GPU 利用率嚴重下降:從 100% 掉至接近 0%

先前在 SSSD 訓練階段中,使用 sampling 方法進行預測時曾遇到效率不佳的情況。
透過降低去噪(噪音還原)步驟數量 T,已成功大幅縮短預測時間。

目前整體訓練的瓶頸僅剩 autoFRK 的填補階段;然而已嘗試的多種辦法(如 torch.compile、pipeline、平行化等)皆無法有效解決此問題。

本次主要程式碼如下:

trainer.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import logging
import os
from typing import Any, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.checkpoint import checkpoint  # New
from tqdm import tqdm
import numpy as np  # New

from sssd.core.model_specs import MASK_FN
from sssd.training.utils import training_loss
from sssd.utils.logger import setup_logger
from sssd.utils.utils import find_max_epoch, std_normal  # New
from autoFRK import autoFRK, to_tensor, garbage_cleaner  # New

# Module-level default logger; DiffusionTrainer falls back to it when no
# logger is passed to its constructor.
LOGGER = setup_logger()


class DiffusionTrainer:
    """
    Train Diffusion Models (SSSD), optionally followed by a periodic autoFRK
    spatial-prediction step whose loss is back-propagated into the network.

    Args:
        dataloader (DataLoader): The training dataloader. Its ``dataset`` must expose
            ``tensors[0]`` (e.g. a ``TensorDataset``); that tensor is the target of the
            spatial-prediction loss.
        diffusion_hyperparams (Dict[str, Any]): Hyperparameters for the diffusion process
            (the sampler reads the keys "T", "Alpha", "Alpha_bar" and "Sigma").
        net (nn.Module): The neural network model to be trained; wrapped in
            ``nn.DataParallel`` and moved to ``device``.
        device (torch.device): The device to be used for training.
        output_directory (str): Directory to save model checkpoints and TensorBoard logs.
        ckpt_iter (Optional[int, str]): The checkpoint iteration to be loaded; 'max' selects
            the maximum iteration; None or a negative value starts from scratch.
        n_iters (int): Number of iterations to train.
        iters_per_ckpt (int): Number of iterations to save checkpoint.
        iters_per_logging (int): Number of iterations to save training log and compute validation loss.
        learning_rate (float): Learning rate for training.
        only_generate_missing (int): Option to generate missing portions of the signal only.
        masking (str): Type of masking strategy: 'mnr' for Missing Not at Random, 'bm' for Blackout Missing, 'rm' for Random Missing.
        missing_k (int): K missing time steps for each feature across the sample length.
        batch_size (int): Size of each training batch.
        enable_spatial_prediction (bool): Whether to run the SSSD + autoFRK spatial step.
        n_cores (Union[int, str]): Stored as ``self.n_cores``; not read by this class.
        autoFRK_period (int): Run the spatial step every ``autoFRK_period`` iterations.
        location_path (str): Path to a ``.npy`` file with the locations passed to autoFRK.
        logger (Optional[logging.Logger]): Logger object for logging, defaults to None.
    """

    def __init__(
        self,
        dataloader: DataLoader,
        diffusion_hyperparams: Dict[str, Any],
        net: nn.Module,
        device: Optional[Union[torch.device, str]],
        output_directory: str,
        ckpt_iter: Union[str, int],
        n_iters: int,
        iters_per_ckpt: int,
        iters_per_logging: int,
        learning_rate: float,
        only_generate_missing: int,
        masking: str,
        missing_k: int,
        batch_size: int,
        enable_spatial_prediction: bool,
        n_cores: Union[int, str],
        autoFRK_period: int,
        location_path: str,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self.dataloader = dataloader
        self.diffusion_hyperparams = diffusion_hyperparams
        self.net = nn.DataParallel(net).to(device)
        self.device = device
        self.output_directory = output_directory
        self.ckpt_iter = ckpt_iter
        self.n_iters = n_iters
        self.iters_per_ckpt = iters_per_ckpt
        self.iters_per_logging = iters_per_logging
        self.learning_rate = learning_rate
        self.only_generate_missing = only_generate_missing
        self.masking = masking
        self.missing_k = missing_k
        self.writer = SummaryWriter(f"{output_directory}/log")
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        self.enable_spatial_prediction = enable_spatial_prediction
        self.n_cores = n_cores
        self.autoFRK_period = autoFRK_period
        self.location_path = location_path
        # Full training tensor, kept on the device as the spatial-step target.
        self.real_data = self.dataloader.dataset.tensors[0].to(self.device)
        self.real_data_shape = self.real_data.shape
        self.logger = logger or LOGGER

        if self.masking not in MASK_FN:
            raise KeyError(f"Please enter a correct masking, but got {self.masking}")

    def _load_checkpoint(self) -> None:
        """Restore network (and, if saved, optimizer) state from a checkpoint.

        ``self.ckpt_iter == "max"`` selects the latest checkpoint found by
        ``find_max_epoch``. A None or negative iteration, or any error while
        loading, leaves the model untouched and sets ``self.ckpt_iter`` to -1
        so that ``train`` starts from iteration 1.
        """
        if self.ckpt_iter == "max":
            self.ckpt_iter = find_max_epoch(self.output_directory)
        # None (e.g. "ckpt_iter" absent from the training config) means "no
        # checkpoint" instead of failing on ``None >= 0``.
        if self.ckpt_iter is not None and self.ckpt_iter >= 0:
            try:
                model_path = os.path.join(
                    self.output_directory, f"{self.ckpt_iter}.pkl"
                )
                # Named ``state`` so it does not shadow torch.utils.checkpoint.checkpoint.
                state = torch.load(model_path, map_location="cpu")

                self.net.load_state_dict(state["model_state_dict"])
                if "optimizer_state_dict" in state:
                    self.optimizer.load_state_dict(state["optimizer_state_dict"])

                self.logger.info(
                    f"Successfully loaded model at iteration {self.ckpt_iter}"
                )
            except Exception as e:
                # Best effort: any load failure falls back to training from scratch.
                self.ckpt_iter = -1
                self.logger.error(f"No valid checkpoint model found. Error: {e}")
        else:
            self.ckpt_iter = -1
            self.logger.info(
                "No valid checkpoint model found, start training from initialization."
            )

    def _save_model(self, n_iter: int) -> None:
        """Save model and optimizer state to ``<output_directory>/<n_iter>.pkl``
        every ``iters_per_ckpt`` iterations (never at iteration 0)."""
        if n_iter > 0 and n_iter % self.iters_per_ckpt == 0:
            torch.save(
                {
                    "model_state_dict": self.net.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                os.path.join(self.output_directory, f"{n_iter}.pkl"),
            )

    def _update_mask(self, batch: torch.Tensor) -> torch.Tensor:
        """Build a float32 mask for ``batch`` on ``self.device``.

        A single mask is generated from the first sample with
        ``MASK_FN[self.masking]``, transposed, and repeated for every sample,
        so all samples in the batch share the same missing pattern.
        """
        transposed_mask = MASK_FN[self.masking](batch[0], self.missing_k)
        return (
            transposed_mask.permute(1, 0)
            .repeat(batch.size()[0], 1, 1)
            .to(self.device, dtype=torch.float32)
        )

    def _sampling_differentiable(
        self,
        net: torch.nn.Module,
        size: Tuple[int, int, int],
        diffusion_hyperparams: Dict[str, torch.Tensor],
        cond: torch.Tensor,
        mask: torch.Tensor,
        only_generate_missing: int = 0,
        device: Union[torch.device, str] = "cpu",
        sampling_mode: str = "ddpm",  # "ddpm" or "ddim"
        step_skip: Union[int, None] = None,  # None -> chosen automatically from T
    ) -> torch.Tensor:
        """Run the reverse diffusion process while keeping the autograd graph.

        Each denoising step is wrapped in ``torch.utils.checkpoint`` so its
        activations are recomputed during backward instead of being stored.

        Args:
            net: Noise-prediction network, called as ``net((x, cond, mask, steps))``.
            size: Shape of the sample to generate; ``size[0]`` is the batch size.
            diffusion_hyperparams: Must contain "T", "Alpha", "Alpha_bar", "Sigma".
            cond: Conditioning tensor (observed values).
            mask: Observation mask; when ``only_generate_missing == 1`` observed
                entries of ``x`` are overwritten with ``cond`` before each step.
            only_generate_missing: See ``mask``.
            device: Device on which the sample is generated.
            sampling_mode: "ddpm" (stochastic) or "ddim" (deterministic).
            step_skip: Stride between visited timesteps; None picks one from T.

        Returns:
            The generated sample with shape ``size``.

        Raises:
            ValueError: If ``sampling_mode`` is not "ddpm" or "ddim".
        """
        _dh = diffusion_hyperparams
        T, Alpha, Alpha_bar, Sigma = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"]
        if sampling_mode not in ("ddpm", "ddim"):
            raise ValueError(f"Unknown sampling_mode: {sampling_mode}")

        x = torch.randn(size, device=device)
        mask = mask.to(device)
        cond = cond.to(device)

        if step_skip is None:
            if T <= 200:
                step_skip = 1  # short chains: visit every step
            elif T <= 1000:
                step_skip = max(1, T // 200)  # roughly 200 steps
            else:
                step_skip = max(1, T // 400)  # roughly 400 steps

        # Descending schedule T-1, T-1-step_skip, ..., always ending at 0.
        timesteps = list(range(T - 1, -1, -step_skip))
        if timesteps[-1] != 0:
            timesteps.append(0)

        def net_step(x_inner: torch.Tensor, t: int) -> torch.Tensor:
            # ``t`` is an explicit argument rather than the enclosing loop
            # variable: checkpoint re-runs this function during backward, after
            # the loop has finished, and must see the same timestep as forward.
            if only_generate_missing == 1:
                x_inner = x_inner * (1 - mask).float() + cond * mask.float()
            diffusion_steps = t * torch.ones((size[0], 1), device=device)
            epsilon_theta = net((x_inner, cond, mask, diffusion_steps))

            if sampling_mode == "ddpm":
                # Stochastic DDPM update.
                # NOTE(review): this uses the single-step Alpha[t] / Sigma[t]
                # even when step_skip > 1, so skipped steps are not accounted
                # for; confirm this is intended (the DDIM branch handles skips).
                x_next = (x_inner - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t])
                if t > 0:
                    x_next = x_next + Sigma[t] * torch.randn_like(x_next)
                return x_next

            # Deterministic DDIM update (eta = 0); supports skipped steps.
            x0_pred = (x_inner - torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha_bar[t])
            if t == 0:
                # Last step: alpha_bar_prev = 1, so the update is exactly x0_pred.
                return x0_pred
            Alpha_bar_prev = Alpha_bar[max(t - step_skip, 0)]
            return torch.sqrt(Alpha_bar_prev) * x0_pred + torch.sqrt(1 - Alpha_bar_prev) * epsilon_theta

        for t in timesteps:
            x = checkpoint(net_step, x, t, use_reentrant=False)

        return x

    def _sssd_prediction_step(self) -> torch.Tensor:
        """Generate an SSSD sample for every batch of the training dataloader.

        Returns:
            The generated batches concatenated along dim 0 and permuted with
            ``permute(1, 2, 0)`` (sample axis last); ``_autoFRK_step`` unpacks
            this as ``(V, T, N)``.
        """
        self.logger.info(f"Start SSSD prediction step")

        all_batches = []
        for (batch,) in tqdm(self.dataloader, desc=f"{self.n_iter}-th predicting TS"):
            mask = self._update_mask(batch)
            batch = batch.permute(0, 2, 1)

            batch_generated = self._sampling_differentiable(
                net=self.net,
                size=batch.shape,
                diffusion_hyperparams=self.diffusion_hyperparams,
                cond=batch,
                mask=mask,
                only_generate_missing=self.only_generate_missing,
                device=self.device,
                sampling_mode=getattr(self, "sampling_mode", "ddpm"),
                # A fixed skip of 10 is used unless ``self.step_skip`` is set;
                # set it to None to let the sampler choose a skip from T.
                step_skip=getattr(self, "step_skip", 10),
            )

            all_batches.append(batch_generated)
            garbage_cleaner()

        return torch.cat(all_batches, dim=0).permute(1, 2, 0)

    def _autoFRK_step(self, sssd_prediction: torch.Tensor) -> torch.Tensor:
        """Fit autoFRK to every (variable, time) slice of the SSSD prediction.

        Each slice ``sssd_prediction[v, t, :]`` is fitted independently at the
        locations loaded from ``self.location_path``, and the fitted
        ``pred.value`` is written back into a tensor of the same shape.
        Presumably gradients flow back to ``sssd_prediction`` because autoFRK is
        called with ``requires_grad=True`` — confirm against autoFRK.

        Args:
            sssd_prediction: Tensor unpacked as ``(V, T, N)``.

        Returns:
            The autoFRK result permuted to ``(N, T, V)``.

        Raises:
            ValueError: If the result's shape differs from ``self.real_data``.
        """
        self.logger.info(f"Start autoFRK inference step")
        dtype = sssd_prediction.dtype
        device = sssd_prediction.device
        V, T, N = sssd_prediction.shape

        loc = to_tensor(np.load(self.location_path), dtype=dtype, device=device)
        autoFRK_result = torch.empty_like(sssd_prediction, dtype=dtype, device=device)
        frk = autoFRK(
            logger_level=30,
            dtype=dtype,
            device=device,
        )
        # NOTE: wrapping autoFRK in torch.compile(mode="reduce-overhead",
        # dynamic=True) was tried and did not remove this loop's cost.

        # Basis passed to frk.forward as G. It stays None until the first
        # LinAlgError below.
        # NOTE(review): caching it only after a failure looks unintended;
        # confirm whether it should be cached after the first successful fit.
        mrts = None
        with tqdm(total=V * T, desc=f"{self.n_iter}-th predicting autoFRK") as pbar:
            for variable in range(V):
                for time_index in range(T):
                    data_slice = sssd_prediction[variable, time_index, :]

                    try:
                        frk.forward(
                            data=data_slice,
                            loc=loc,
                            G=mrts,
                            method="fast",
                            tps_method="rectangular",
                            requires_grad=True,
                        )
                        pred = frk.predict()['pred.value']
                    except torch._C._LinAlgError:
                        self.logger.warning(f"Skipped variable={variable}, time={time_index} due to ill-conditioned matrix")
                        # Fall back to an all-zero prediction for this slice.
                        pred = torch.zeros((N, 1), dtype=dtype, device=device, requires_grad=True)
                        mrts = frk.obj['G']['MRTS'] if mrts is None else mrts

                    # pred is expected to be (N, 1); store it as a length-N row.
                    autoFRK_result[variable, time_index, :] = pred.T
                    pbar.update(1)
        autoFRK_result = autoFRK_result.permute(2, 1, 0)

        if autoFRK_result.shape != self.real_data_shape:
            error_msg = f"Shape mismatch: autoFRK_result {autoFRK_result.shape} != real_data {self.real_data_shape}"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        return autoFRK_result

    def _train_per_epoch(self) -> torch.Tensor:
        """Run one pass of SSSD training, then optionally the spatial step.

        Every batch gets one optimizer step on the SSSD training loss. When
        spatial prediction is enabled and ``self.n_iter`` is a multiple of
        ``autoFRK_period``, one more optimizer step is taken on the MSE between
        the SSSD + autoFRK reconstruction and ``self.real_data``.

        Returns:
            The loss of the last optimizer step taken.

        Raises:
            RuntimeError: If the dataloader yields no batches.
        """
        loss_function = nn.MSELoss()
        loss = None

        for (batch,) in tqdm(self.dataloader, desc=f"{self.n_iter}-th   training TS"):
            batch = batch.to(self.device)
            mask = self._update_mask(batch)
            loss_mask = ~mask.bool()

            batch = batch.permute(0, 2, 1)
            assert batch.size() == mask.size() == loss_mask.size()

            self.optimizer.zero_grad()
            loss = training_loss(
                model=self.net,
                loss_function=loss_function,
                training_data=(batch, batch, mask, loss_mask),
                diffusion_parameters=self.diffusion_hyperparams,
                generate_only_missing=self.only_generate_missing,
                device=self.device,
            )
            loss.backward()
            self.optimizer.step()

        if loss is None:
            raise RuntimeError("The training dataloader yielded no batches.")

        if self.enable_spatial_prediction and self.n_iter % self.autoFRK_period == 0:
            self.logger.info(f"Iteration {self.n_iter}: Start Spatial Prediction step")
            sssd_prediction = self._sssd_prediction_step()
            garbage_cleaner()
            autoFRK_result = self._autoFRK_step(sssd_prediction=sssd_prediction)

            self.optimizer.zero_grad()
            loss = loss_function(
                autoFRK_result,
                self.real_data
            )
            loss.backward()
            self.optimizer.step()

            self.logger.info(f"Iteration {self.n_iter}: Spatial Prediction step done, loss: {loss.item()}")

        return loss

    def train(self) -> None:
        """Train from the restored (or initial) iteration up to ``n_iters``.

        Each iteration runs ``_train_per_epoch`` and writes its loss to
        TensorBoard. The loss is also logged every ``iters_per_logging``
        iterations, and a checkpoint is saved every ``iters_per_ckpt``.
        """
        self._load_checkpoint()

        # A fresh start (ckpt_iter == -1) begins at iteration 1; otherwise
        # resume right after the loaded checkpoint.
        n_iter_start = (
            self.ckpt_iter + 2 if self.ckpt_iter == -1 else self.ckpt_iter + 1
        )
        self.logger.info(f"Start the {n_iter_start} iteration")

        for n_iter in range(n_iter_start, self.n_iters + 1):
            # Read by the other methods for progress labels and the autoFRK period.
            self.n_iter = n_iter
            loss = self._train_per_epoch()
            self.writer.add_scalar("Train/Loss", loss.item(), n_iter)
            if n_iter % self.iters_per_logging == 0:
                self.logger.info(f"Iteration: {n_iter} \tLoss: { loss.item()}")
            self._save_model(n_iter)

方法與結論

此次預測結果如下,訓練階段皆使用 SSSD + autoFRK 於 Python 完成,推論階段於 R 完成。

| Metric | ALL Locs & All Time | Known Locs & All Time | Unknown Locs & All Time | ALL Locs & Future | Known Locs & Future | Unknown Locs & Future | ALL Locs & Past | Known Locs & Past | Unknown Locs & Past |
|---|---|---|---|---|---|---|---|---|---|
| MSPE | 5.784901e+00 | 5.827523e+00 | 5.614871e+00 | 9.930628e+00 | 9.855177e+00 | 1.023162e+01 | 5.614036e+00 | 5.661524e+00 | 5.424592e+00 |
| RMSPE | 2.405182e+00 | 2.414026e+00 | 2.369572e+00 | 3.151290e+00 | 3.139296e+00 | 3.198691e+00 | 2.369396e+00 | 2.379396e+00 | 2.329075e+00 |
| MSPE% | 1.198381e+06 | 9.619680e+05 | 2.141505e+06 | 1.938496e+06 | 2.069548e+06 | 1.415692e+06 | 1.167877e+06 | 9.163194e+05 | 2.171419e+06 |
| RMSPE% | 1.094706e+03 | 9.807997e+02 | 1.463388e+03 | 1.392299e+03 | 1.438592e+03 | 1.189829e+03 | 1.080684e+03 | 9.572457e+02 | 1.473574e+03 |
| MAPE | 1.728617e+00 | 1.732159e+00 | 1.714489e+00 | 2.330128e+00 | 2.318451e+00 | 2.376710e+00 | 1.703826e+00 | 1.707995e+00 | 1.687196e+00 |
| MAPE% | 4.028616e+05 | 4.087173e+05 | 3.795013e+05 | 4.943506e+05 | 5.341178e+05 | 3.357075e+05 | 3.990909e+05 | 4.035490e+05 | 3.813063e+05 |

基礎 SSSD 模型在未知地點與未來時段的測試中,MSPE 為 1.079069e+01、RMSPE 為 3.284918e+00。加入空間統計補強後,SSSD + autoFRK (R) 的誤差分別降至 1.019189e+01 與 3.192474e+00,顯示 autoFRK 能有效提升空間插值的穩定性。

進一步採用 Python 版本的整合模型 SSSD + autoFRK (Python),在相同區段中 MSPE 為 1.023162e+01、RMSPE 為 3.198691e+00,略高於 R 版的 1.019189e+01 與 3.192474e+00,但仍明顯低於基礎 SSSD。相對百分比誤差(MSPE% 與 RMSPE%)分別為 1.415692e+06 與 1.189829e+03,亦略高於 R 版的 1.400531e+06 與 1.183440e+03,兩版本差距不大。

此外,MAPE 由 SSSD 的 2.439223 降至 autoFRK (R) 的 2.377173,進一步下降至 autoFRK (Python) 的 2.376710。

整體而言,在「未知地點 × 未來時段」的預測中,SSSD + autoFRK (Python) 與 R 版表現相當:其 MAPE 為三者中最低,MSPE 與 RMSPE 則略高於 R 版,而兩者皆優於基礎 SSSD,顯示結合 autoFRK 有助於提升跨時間與空間的泛化能力,但仍有進一步降低誤差的空間。

其他嘗試

上述實驗在訓練完成後,仍須改用 R 版本的 autoFRK 輔助預測;目前的優化方向是嘗試從訓練到推論完全使用 Python 進行。目前遇到的新問題是:Python 版本 autoFRK 的 predict 方法在提供參數 newloc 時,記憶體使用量會急遽飆升。其根本原因在於:依新地點資料長度建立索引後,會再依索引長度生成單位矩陣,該矩陣過大導致記憶體不足,目前估計的最大記憶體需求甚至達到 5TB。目前已嘗試改以稀疏矩陣方式儲存,但受限於實驗時間與技術,此修改需要進行大範圍、大規模的更新,故先行擱置。

整合 Python 版本 SSSD 與 autoFRK 後的程式碼如下。因理工一館進行停電檢修,目前尚未進行程式碼驗證與試跑;配套的資料輸入與輸出腳本也仍在撰寫中。

train.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
import argparse
import os
from typing import Optional, Union

import torch
import yaml

from sssd.core.model_specs import MODEL_PATH_FORMAT, setup_model
from sssd.training.trainer import DiffusionTrainer
from sssd.utils.logger import setup_logger
from sssd.utils.utils import calc_diffusion_hyperparams, display_current_time

# Module-level logger shared by the helpers below and the __main__ entry point;
# it is also handed to DiffusionTrainer in run_job.
LOGGER = setup_logger()


def fetch_args() -> argparse.Namespace:
    """Parse the command line of the training entry point.

    Returns:
        Namespace with ``model_config`` and ``training_config``: paths to the
        model and training YAML files (defaulting to ``configs/model.yaml`` and
        ``configs/training.yaml``).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("-m", "--model_config", "configs/model.yaml", "Model configuration"),
        ("-t", "--training_config", "configs/training.yaml", "Training configuration"),
    )
    for short_flag, long_flag, default_path, help_text in option_specs:
        cli.add_argument(
            short_flag,
            long_flag,
            type=str,
            default=default_path,
            help=help_text,
        )
    return cli.parse_args()


def setup_output_directory(
    model_config: dict,
    training_config: dict,
) -> str:
    """Create (if missing) and return the checkpoint directory for this run.

    The sub-directory name is ``MODEL_PATH_FORMAT`` filled with the diffusion
    schedule (``T``, ``beta_0``, ``beta_T``) from ``model_config["diffusion"]``,
    placed under ``training_config["output_directory"]``. A newly created
    directory gets permissions 0o775.
    """
    diffusion_cfg = model_config["diffusion"]
    run_name = MODEL_PATH_FORMAT.format(
        T=diffusion_cfg["T"],
        beta_0=diffusion_cfg["beta_0"],
        beta_T=diffusion_cfg["beta_T"],
    )
    target_dir = os.path.join(training_config["output_directory"], run_name)

    already_exists = os.path.isdir(target_dir)
    if not already_exists:
        os.makedirs(target_dir)
        os.chmod(target_dir, 0o775)
    LOGGER.info("Output directory %s", target_dir)
    return target_dir


def run_job(
    model_config: dict,
    training_config: dict,
    device: Optional[Union[torch.device, str]],
) -> None:
    """Build the diffusion model and trainer from the two configs, then train.

    Args:
        model_config: Parsed model YAML; its "diffusion" section feeds
            ``calc_diffusion_hyperparams`` and the output-directory name.
        training_config: Parsed training YAML (data path, schedule, masking,
            autoFRK options, ...).
        device: Device for the diffusion hyperparameters, model and trainer.
    """
    output_directory = setup_output_directory(model_config, training_config)
    diffusion_hyperparams = calc_diffusion_hyperparams(
        **model_config["diffusion"], device=device
    )
    net = setup_model(training_config["use_model"], model_config, device)

    LOGGER.info(display_current_time())

    cfg = training_config.get
    trainer = DiffusionTrainer(
        data_path=training_config["data"]["train_path"],
        diffusion_hyperparams=diffusion_hyperparams,
        net=net,
        device=device,
        output_directory=output_directory,
        ckpt_iter=cfg("ckpt_iter"),
        n_iters=cfg("n_iters"),
        iters_per_ckpt=cfg("iters_per_ckpt"),
        iters_per_logging=cfg("iters_per_logging"),
        learning_rate=cfg("learning_rate"),
        only_generate_missing=cfg("only_generate_missing"),
        masking=cfg("masking"),
        missing_k=cfg("missing_k"),
        batch_size=cfg("batch_size"),
        # Spatial (autoFRK) step options.
        enable_spatial_prediction=cfg("enable_spatial_prediction", True),
        autoFRK_period=cfg("autoFRK_period"),
        location_path=os.path.abspath(training_config["location_path"]),
        logger=LOGGER,
    )
    trainer.train()

    LOGGER.info(display_current_time())

    LOGGER.info(display_current_time())


if __name__ == "__main__":
    args = fetch_args()

    # Both specs are plain YAML; safe_load refuses arbitrary Python object tags.
    with open(args.model_config, "rt") as model_file:
        model_config = yaml.safe_load(model_file.read())
    with open(args.training_config, "rt") as training_file:
        training_config = yaml.safe_load(training_file.read())

    LOGGER.info(f"Model spec: {model_config}")
    LOGGER.info(f"Training spec: {training_config}")

    gpu_count = torch.cuda.device_count()
    if gpu_count > 0:
        LOGGER.info(f"Using {gpu_count} GPUs!")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    run_job(model_config, training_config, device)

trainer.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import logging
import os
from typing import Any, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.checkpoint import checkpoint  # New
from tqdm import tqdm
import numpy as np  # New

from sssd.core.model_specs import MASK_FN
from sssd.training.utils import training_loss
from sssd.data.utils import get_dataloader
from sssd.utils.logger import setup_logger
from sssd.utils.utils import find_max_epoch, std_normal  # New
from autoFRK import AutoFRK, to_tensor, garbage_cleaner  # New

# Module-level logger created by the project's setup_logger helper.
LOGGER = setup_logger()


class DiffusionTrainer:
    """
    Train an SSSD diffusion model, optionally followed by a periodic spatial (autoFRK) step.

    Args:
        data_path (str): Path to the ``.npy`` training data; loaded and standardised by ``get_dataloader``.
        diffusion_hyperparams (Dict[str, Any]): Hyperparameters for the diffusion process
            (the sampler reads the keys "T", "Alpha", "Alpha_bar" and "Sigma").
        net (nn.Module): The neural network model to be trained (wrapped in ``nn.DataParallel``).
        device (torch.device): The device to be used for training.
        output_directory (str): Directory for model checkpoints and TensorBoard logs.
        ckpt_iter (Union[int, str]): The checkpoint iteration to be loaded; 'max' selects the maximum iteration.
        n_iters (int): Number of iterations to train.
        iters_per_ckpt (int): Number of iterations between saved checkpoints.
        iters_per_logging (int): Number of iterations between training-loss log lines.
        learning_rate (float): Learning rate for the Adam optimizer.
        only_generate_missing (int): Option to generate missing portions of the signal only.
        masking (str): Type of masking strategy: 'mnr' for Missing Not at Random, 'bm' for Blackout Missing, 'rm' for Random Missing.
        missing_k (int): K missing time steps for each feature across the sample length.
        batch_size (int): Size of each training batch.
        enable_spatial_prediction (bool): Whether to run the SSSD-sampling + autoFRK step.
        autoFRK_period (int): Run the spatial step on iterations divisible by this value.
        location_path (str): Path to a ``.npy`` array of locations passed to autoFRK as ``loc``.
        logger (Optional[logging.Logger]): Logger object for logging, defaults to None.

    Raises:
        KeyError: If ``masking`` is not a key of ``MASK_FN``.
    """

    def __init__(
        self,
        data_path: str,
        diffusion_hyperparams: Dict[str, Any],
        net: nn.Module,
        device: Optional[Union[torch.device, str]],
        output_directory: str,
        ckpt_iter: Union[str, int],
        n_iters: int,
        iters_per_ckpt: int,
        iters_per_logging: int,
        learning_rate: float,
        only_generate_missing: int,
        masking: str,
        missing_k: int,
        batch_size: int,
        enable_spatial_prediction: bool,
        autoFRK_period: int,
        location_path: str,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        loader, ts_mean, ts_std = get_dataloader(
            path=data_path,
            batch_size=batch_size,
            device=device,
        )
        self.dataloader = loader
        # get_dataloader computes these statistics on the CPU; keep them on the
        # training device so un-standardising device tensors (below and in
        # _autoFRK_step) does not mix CPU and GPU operands.
        self.ts_mean = ts_mean.to(device)
        self.ts_std = ts_std.to(device)
        self.diffusion_hyperparams = diffusion_hyperparams
        self.net = nn.DataParallel(net).to(device)
        self.device = device
        self.output_directory = output_directory
        self.ckpt_iter = ckpt_iter
        self.n_iters = n_iters
        self.iters_per_ckpt = iters_per_ckpt
        self.iters_per_logging = iters_per_logging
        self.learning_rate = learning_rate
        self.only_generate_missing = only_generate_missing
        self.masking = masking
        self.missing_k = missing_k
        self.writer = SummaryWriter(f"{output_directory}/log")
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        self.enable_spatial_prediction = enable_spatial_prediction
        self.autoFRK_period = autoFRK_period
        self.location_path = location_path
        # Ground truth on the original (un-standardised) scale; the spatial step
        # compares the autoFRK output against it.
        self.real_data = (
            self.dataloader.dataset.tensors[0].to(self.device) * self.ts_std + self.ts_mean
        )
        self.real_data_shape = self.real_data.shape
        self.logger = logger or LOGGER

        if self.masking not in MASK_FN:
            raise KeyError(f"Please enter a correct masking, but got {self.masking}")

    def _load_checkpoint(self) -> None:
        """Load ``{ckpt_iter}.pkl`` from the output directory into the model and optimizer.

        ``ckpt_iter == "max"`` resolves to the newest checkpoint via ``find_max_epoch``.
        On any load failure, or when no checkpoint is requested, ``self.ckpt_iter``
        is set to -1 so training starts from scratch.
        """
        if self.ckpt_iter == "max":
            self.ckpt_iter = find_max_epoch(self.output_directory)
        # ckpt_iter may arrive as a numeric string (e.g. from a CLI flag).
        self.ckpt_iter = int(self.ckpt_iter)
        if self.ckpt_iter >= 0:
            try:
                model_path = os.path.join(
                    self.output_directory, f"{self.ckpt_iter}.pkl"
                )
                # Named `state` so it does not shadow torch.utils.checkpoint.checkpoint.
                state = torch.load(model_path, map_location="cpu")

                self.net.load_state_dict(state["model_state_dict"])
                if "optimizer_state_dict" in state:
                    self.optimizer.load_state_dict(state["optimizer_state_dict"])

                self.logger.info(
                    f"Successfully loaded model at iteration {self.ckpt_iter}"
                )
            except Exception as e:
                self.ckpt_iter = -1
                self.logger.error(f"No valid checkpoint model found. Error: {e}")
        else:
            self.ckpt_iter = -1
            self.logger.info(
                "No valid checkpoint model found, start training from initialization."
            )

    def _save_model(self, n_iter: int) -> None:
        """Save model and optimizer state every ``iters_per_ckpt`` iterations (never at 0)."""
        if n_iter > 0 and n_iter % self.iters_per_ckpt == 0:
            torch.save(
                {
                    "model_state_dict": self.net.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                os.path.join(self.output_directory, f"{n_iter}.pkl"),
            )

    def _update_mask(self, batch: torch.Tensor) -> torch.Tensor:
        """Build one mask from the first sample, transpose it and repeat it over the batch.

        Returns a float32 tensor on ``self.device`` of shape (batch, *transposed mask shape).
        """
        transposed_mask = MASK_FN[self.masking](batch[0], self.missing_k)
        return (
            transposed_mask.permute(1, 0)
            .repeat(batch.size()[0], 1, 1)
            .to(self.device, dtype=torch.float32)
        )

    def _sampling_differentiable(
        self,
        net: torch.nn.Module,
        size: Tuple[int, int, int],
        diffusion_hyperparams: Dict[str, torch.Tensor],
        cond: torch.Tensor,
        mask: torch.Tensor,
        only_generate_missing: int = 0,
        device: Union[torch.device, str] = "cpu",
        sampling_mode: str = "ddpm",
        step_skip: Union[int, None] = None,
    ) -> torch.Tensor:
        """Run reverse diffusion while keeping the graph differentiable w.r.t. ``net``.

        Each denoising step is wrapped in ``torch.utils.checkpoint`` so activations
        are recomputed during backward instead of being stored for every step.

        Args:
            net: Noise-prediction network, called as ``net((x, cond, mask, steps))``.
            size: Shape of the sample to generate.
            diffusion_hyperparams: Must contain "T", "Alpha", "Alpha_bar" and "Sigma".
            cond: Conditioning tensor (also used to overwrite observed entries).
            mask: Observation mask; 1 marks entries copied from ``cond``.
            only_generate_missing: If 1, observed entries are reset to ``cond`` before each step.
            device: Device for the sample and noise.
            sampling_mode: "ddpm" (stochastic) or "ddim" (deterministic).
            step_skip: Stride through the timesteps; ``None`` picks one from ``T``.

        Returns:
            torch.Tensor: The generated sample with shape ``size``.

        Raises:
            ValueError: If ``sampling_mode`` is not "ddpm" or "ddim".
        """
        if sampling_mode not in ("ddpm", "ddim"):
            raise ValueError(f"Unknown sampling_mode: {sampling_mode}")

        _dh = diffusion_hyperparams
        T, Alpha, Alpha_bar, Sigma = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"]
        x = torch.randn(size, device=device)
        mask = mask.to(device)
        cond = cond.to(device)

        if step_skip is None:
            if T <= 200:
                step_skip = 1  # small T: visit every step
            elif T <= 1000:
                step_skip = max(1, T // 200)  # roughly 200 steps
            else:
                step_skip = max(1, T // 400)  # roughly 400 steps

        # Strided schedule T-1, T-1-s, ..., always ending at 0.
        timesteps = list(range(T - 1, -1, -step_skip))
        if timesteps[-1] != 0:
            timesteps.append(0)
        # The timestep each step jumps to; None after the final (t == 0) step.
        next_steps = timesteps[1:] + [None]

        for t, t_prev in zip(timesteps, next_steps):
            # t / t_prev are bound as default arguments on purpose: non-reentrant
            # checkpoint re-runs net_step during backward, after this loop has
            # finished, and a plain closure would then see the loop's final values
            # (t == 0) for every step, producing wrong gradients.
            def net_step(x_inner, t=t, t_prev=t_prev):
                if only_generate_missing == 1:
                    x_inner = x_inner * (1 - mask).float() + cond * mask.float()
                diffusion_steps = t * torch.ones((size[0], 1), device=device)
                epsilon_theta = net((x_inner, cond, mask, diffusion_steps))

                if sampling_mode == "ddpm":
                    if t_prev is None or t_prev == t - 1:
                        # Consecutive step: standard DDPM posterior update.
                        x_next = (
                            x_inner - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta
                        ) / torch.sqrt(Alpha[t])
                        if t > 0:
                            x_next = x_next + Sigma[t] * torch.randn_like(x_next)
                    else:
                        # Skipped steps: DDPM posterior for the jump t -> t_prev,
                        # using the effective alpha of the whole stride. It reduces to
                        # the branch above when t_prev == t - 1.
                        alpha_step = Alpha_bar[t] / Alpha_bar[t_prev]
                        x_next = (
                            x_inner - (1 - alpha_step) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta
                        ) / torch.sqrt(alpha_step)
                        sigma_step = torch.sqrt(
                            (1 - Alpha_bar[t_prev]) / (1 - Alpha_bar[t]) * (1 - alpha_step)
                        )
                        x_next = x_next + sigma_step * torch.randn_like(x_next)
                else:
                    # DDIM: deterministic update, supports arbitrary strides.
                    x0_pred = (
                        x_inner - torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta
                    ) / torch.sqrt(Alpha_bar[t])
                    if t_prev is None:
                        x_next = torch.sqrt(Alpha_bar[t]) * x0_pred
                    else:
                        x_next = (
                            torch.sqrt(Alpha_bar[t_prev]) * x0_pred
                            + torch.sqrt(1 - Alpha_bar[t_prev]) * epsilon_theta
                        )
                return x_next

            x = checkpoint(net_step, x, use_reentrant=False)

        return x

    def _sssd_prediction_step(self) -> torch.Tensor:
        """Sample every series in the training set with the diffusion model, in dataset order.

        Returns:
            torch.Tensor: The per-batch samples concatenated along dim 0 and then
            permuted with ``permute(1, 2, 0)``; presumably (V, L, N) with N aligned
            to the rows of the location file (TODO confirm).
        """
        self.logger.info("Start SSSD prediction step")

        # self.dataloader shuffles. Here the sample axis must stay in dataset order:
        # autoFRK pairs it with the rows of the location file, and the loss pairs
        # it with self.real_data.
        ordered_loader = DataLoader(
            self.dataloader.dataset, batch_size=self.batch_size, shuffle=False
        )

        all_batches = []
        for (batch,) in tqdm(ordered_loader, desc=f"{self.n_iter}-th predicting TS"):
            mask = self._update_mask(batch)
            batch = batch.permute(0, 2, 1)

            batch_generated = self._sampling_differentiable(
                net=self.net,
                size=batch.shape,
                diffusion_hyperparams=self.diffusion_hyperparams,
                cond=batch,
                mask=mask,
                only_generate_missing=self.only_generate_missing,
                device=self.device,
                sampling_mode=getattr(self, "sampling_mode", "ddpm"),
                # Defaults to a stride of 10; set self.step_skip = None to auto-detect.
                step_skip=getattr(self, "step_skip", 10),
            )

            all_batches.append(batch_generated)
            garbage_cleaner()

        return torch.cat(all_batches, dim=0).permute(1, 2, 0)

    def _autoFRK_step(self, sssd_prediction: torch.Tensor) -> torch.Tensor:
        """Fit autoFRK to every (variable, time) slice of the un-standardised SSSD samples.

        Args:
            sssd_prediction: Output of ``_sssd_prediction_step``, unpacked as (V, T, N).

        Returns:
            torch.Tensor: autoFRK predictions permuted back to ``self.real_data``'s layout.

        Raises:
            ValueError: If the result's shape differs from ``self.real_data``.
        """
        self.logger.info("Start autoFRK inference step")
        dtype = sssd_prediction.dtype
        device = sssd_prediction.device

        # Undo the per-series standardisation. permute(2, 1, 0) + expand line the
        # stored statistics (singleton axis 1) up with the (V, T, N) layout.
        n_time = sssd_prediction.shape[1]
        ts_mean = self.ts_mean.to(device).permute(2, 1, 0).expand(-1, n_time, -1)
        ts_std = self.ts_std.to(device).permute(2, 1, 0).expand(-1, n_time, -1)
        sssd_prediction = sssd_prediction * ts_std + ts_mean

        V, T, N = sssd_prediction.shape

        loc = to_tensor(np.load(self.location_path), dtype=dtype, device=device)
        frk = AutoFRK(
            logger_level=30,
            dtype=dtype,
            device=device,
        )

        # Collect one prediction per slice and stack once at the end rather than
        # writing V*T slices in place into a preallocated tensor, which builds a
        # long chain of in-place copy nodes in the autograd graph.
        slice_preds = []
        mrts = None  # basis from the first successful fit, reused for later slices
        with tqdm(total=V * T, desc=f"{self.n_iter}-th predicting autoFRK") as pbar:
            for variable in range(V):
                for time_index in range(T):
                    data_slice = sssd_prediction[variable, time_index, :]

                    try:
                        frk.forward(
                            data=data_slice,
                            loc=loc,
                            G=mrts,
                            method="fast",
                            tps_method="rectangular",
                            requires_grad=True,
                        )
                        pred = frk.predict()["pred.value"]
                        if mrts is None:
                            mrts = frk.obj["G"]["MRTS"]
                    except torch._C._LinAlgError:
                        self.logger.warning(f"Skipped variable={variable}, time={time_index} due to ill-conditioned matrix")
                        pred = torch.zeros((N, 1), dtype=dtype, device=device, requires_grad=True)

                    slice_preds.append(pred.reshape(N).to(device=device, dtype=dtype))
                    pbar.update(1)

        autoFRK_result = torch.stack(slice_preds).reshape(V, T, N).permute(2, 1, 0)

        if autoFRK_result.shape != self.real_data_shape:
            error_msg = f"Shape mismatch: autoFRK_result {autoFRK_result.shape} != real_data {self.real_data_shape}"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        return autoFRK_result

    def _train_per_epoch(self) -> torch.Tensor:
        """Run one pass of SSSD training, then (periodically) the spatial autoFRK step.

        Returns:
            torch.Tensor: The last loss computed (the spatial loss when that step ran).

        Raises:
            RuntimeError: If the training dataloader yields no batches.
        """
        loss_function = nn.MSELoss()
        loss = None

        # SSSD training
        for (batch,) in tqdm(self.dataloader, desc=f"{self.n_iter}-th   training TS"):
            batch = batch.to(self.device)
            mask = self._update_mask(batch)
            loss_mask = ~mask.bool()

            batch = batch.permute(0, 2, 1)
            assert batch.size() == mask.size() == loss_mask.size()

            self.optimizer.zero_grad()
            loss = training_loss(
                model=self.net,
                loss_function=loss_function,
                training_data=(batch, batch, mask, loss_mask),
                diffusion_parameters=self.diffusion_hyperparams,
                generate_only_missing=self.only_generate_missing,
                device=self.device,
            )
            loss.backward()
            self.optimizer.step()

        if loss is None:
            raise RuntimeError("The training dataloader produced no batches.")

        if self.enable_spatial_prediction and self.n_iter % self.autoFRK_period == 0:
            self.logger.info(f"Iteration {self.n_iter}: Start Spatial Prediction step")
            sssd_prediction = self._sssd_prediction_step()
            garbage_cleaner()
            autoFRK_result = self._autoFRK_step(sssd_prediction=sssd_prediction)

            self.optimizer.zero_grad()
            loss = loss_function(autoFRK_result, self.real_data)
            loss.backward()
            self.optimizer.step()

            self.logger.info(f"Iteration {self.n_iter}: Spatial Prediction step done, loss: {loss.item()}")

        return loss

    def train(self) -> None:
        """Train from the loaded checkpoint (or from scratch) up to ``n_iters`` inclusive."""
        self._load_checkpoint()

        # Start at 1 when nothing was loaded, otherwise right after the checkpoint.
        n_iter_start = (
            self.ckpt_iter + 2 if self.ckpt_iter == -1 else self.ckpt_iter + 1
        )
        self.logger.info(f"Start the {n_iter_start} iteration")

        for n_iter in range(n_iter_start, self.n_iters + 1):
            self.n_iter = n_iter
            loss = self._train_per_epoch()
            self.writer.add_scalar("Train/Loss", loss.item(), n_iter)
            if n_iter % self.iters_per_logging == 0:
                self.logger.info(f"Iteration: {n_iter} \tLoss: { loss.item()}")
            self._save_model(n_iter)

utils.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import random
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset


def merge_all_time(df: pd.DataFrame) -> pd.DataFrame:
    """
    Expand every zone's series onto the complete hourly time grid.

    Hours with no observation for a zone get a row whose non-key columns are NaN.

    Args:
    df (DataFrame): DataFrame containing 'Date', 'Zone', and 'Load' columns.

    Returns:
    DataFrame: A DataFrame with the same columns, zones stacked one after another.
    The number of rows is (number of hours between min and max Date) * (number of zones).
    """
    # One row per hour between the first and last timestamp ("h" = hourly; the
    # upper-case "H" alias is deprecated in recent pandas).
    hours_df = pd.DataFrame(
        {"Date": pd.date_range(start=df["Date"].min(), end=df["Date"].max(), freq="1h")}
    )

    per_zone = []
    for zone in df["Zone"].unique():
        # Left-merge keeps every hour; hours missing for this zone get NaN values.
        result = pd.merge(hours_df, df.loc[df["Zone"] == zone], on="Date", how="left")
        result["Zone"] = zone
        per_zone.append(result)

    # Concatenate once instead of re-copying the accumulated frame on every zone.
    return pd.concat(per_zone, axis=0)


def load_testing_data(test_data_path: str, num_samples: int) -> torch.Tensor:
    """
    Read a ``.npy`` test array and regroup it into batches of ``num_samples`` rows.

    Args:
    - test_data_path (str): Location of the ``.npy`` file with the test samples.
    - num_samples (int): Rows per batch; the row count must split evenly (``np.split``).

    Returns:
    - torch.Tensor: float32 tensor of shape (n_batches, num_samples, ...),
      placed on the default CUDA device when one is available.
    """
    raw = np.load(test_data_path)
    n_batches = raw.shape[0] // num_samples
    stacked = np.array(np.split(raw, n_batches))
    as_tensor = torch.from_numpy(stacked).float()
    return as_tensor.cuda() if torch.cuda.is_available() else as_tensor


def load_and_split_training_data(
    training_data_load: np.ndarray,
    batch_num: int,
    batch_size: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Randomly draw ``batch_num * batch_size`` distinct samples and group them into batches.

    Args:
        training_data_load (np.ndarray): The training data; samples along axis 0.
        batch_num (int): The number of batches to create.
        batch_size (int): The size of each batch.
        device (torch.device): The device to move the data to.

    Returns:
        torch.Tensor: float32 tensor of shape (batch_num, batch_size, ...) on ``device``.

    Raises:
        ValueError: If a single batch, or all requested batches together, need more
            samples than the training data contains (sampling is without replacement).
    """
    total_samples = training_data_load.shape[0]
    if batch_size > total_samples:
        raise ValueError(
            "Batch size exceeds the total number of samples in the training data"
        )
    requested = batch_num * batch_size
    # random.sample draws without replacement, so the whole request must fit.
    if requested > total_samples:
        raise ValueError(
            f"Requested {batch_num} batches of {batch_size} ({requested} samples) "
            f"but the training data only has {total_samples} samples"
        )

    indices = random.sample(range(total_samples), requested)
    selected = training_data_load[indices]
    batched = np.array(np.split(selected, batch_num, 0))
    return torch.from_numpy(batched).to(device, dtype=torch.float32)


def get_dataloader(
    path: str,
    batch_size: int,
    is_shuffle: bool = True,
    device: Union[str, torch.device] = "cpu",
    num_workers: int = 0,
    normalize: bool = True,
    inference: bool = False,
    missing_k: Optional[int] = None,
) -> Tuple[DataLoader, torch.Tensor, torch.Tensor]:
    """
    Build a DataLoader over a ``.npy`` array, optionally standardising each series.

    Statistics are taken over axis 1 (keepdim), so they can be used to map model
    outputs back to the original scale via ``x * ts_std + ts_mean``.

    Args:
        path (str): Path to the ``.npy`` dataset file.
        batch_size (int): Size of each batch.
        is_shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
        device (Union[str, torch.device], optional): Target device; only used to decide
            whether to pin host memory (any CUDA device). The data itself stays on the
            CPU. Defaults to "cpu".
        num_workers (int, optional): Number of subprocesses for data loading. Defaults to 0.
        normalize (bool, optional): Standardise each series by its mean/std over axis 1.
            When False, the returned mean is 0 and std is 1. Defaults to True.
        inference (bool, optional): Append ``missing_k`` zero steps along axis 1 (the
            horizon to be generated). Requires a 3-D array. Defaults to False.
        missing_k (Optional[int], optional): Number of zero steps to append in inference mode.

    Returns:
        Tuple[DataLoader, torch.Tensor, torch.Tensor]: The loader, ``ts_mean`` and ``ts_std``.

    Raises:
        ValueError: If ``inference`` is True and ``missing_k`` is None.
    """
    data = torch.from_numpy(np.load(path)).to(dtype=torch.float32)

    if normalize:
        ts_mean = data.mean(dim=1, keepdim=True)
        ts_std = data.std(dim=1, unbiased=False, keepdim=True)
        ts_std[ts_std == 0] = 1.0  # constant series: avoid division by zero
        data = (data - ts_mean) / ts_std
    else:
        # Identity statistics with the same shape as the normalize branch. Built
        # directly rather than as mean*0 / std*0+1, which turn NaN data into NaN stats.
        ts_mean = torch.zeros_like(data[:, :1])
        ts_std = torch.ones_like(data[:, :1])

    if inference:
        if missing_k is None:
            raise ValueError("In inference mode, missing_k must be specified.")
        zeros = torch.zeros(
            (data.shape[0], missing_k, data.shape[2]), dtype=data.dtype, device=data.device
        )
        data = torch.cat([data, zeros], dim=1)

    dataset = TensorDataset(data)
    # Pin host memory for any CUDA target, including indexed ones such as "cuda:1".
    pin_memory = device is not None and torch.device(device).type == "cuda"
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_shuffle,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )

    return loader, ts_mean, ts_std

infer.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
from typing import Optional, Union

import os  # New
import torch
import torch.nn as nn
import yaml

from sssd.core.model_specs import MODEL_PATH_FORMAT, setup_model
from sssd.inference.generator import DiffusionGenerator
from sssd.utils.logger import setup_logger
from sssd.utils.utils import calc_diffusion_hyperparams, display_current_time

# Module-level logger used by run_job and the __main__ entry point below.
LOGGER = setup_logger()


def fetch_args() -> argparse.Namespace:
    """Parse the inference CLI: model config, inference config and checkpoint selector."""
    parser = argparse.ArgumentParser()
    flag_specs = (
        (
            ("-m", "--model_config"),
            {"type": str, "default": "configs/model.yaml", "help": "Model configuration"},
        ),
        (
            ("-i", "--inference_config"),
            {
                "type": str,
                "default": "configs/inference_config.yaml",
                "help": "Inference configuration",
            },
        ),
        (
            ("-ckpt_iter", "--ckpt_iter"),
            {
                "default": "max",
                "help": 'Which checkpoint to use; assign a number or "max" to find the latest checkpoint',
            },
        ),
    )
    for flags, options in flag_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()


def run_job(
    model_config: dict,
    inference_config: dict,
    device: Optional[Union[torch.device, str]],
    ckpt_iter: Union[str, int],
) -> None:
    """Run ``trials`` inference passes with DiffusionGenerator.

    The first trial saves every array in ``data_names``; later trials only save the
    imputation. With more than one trial, outputs go to ``<output_directory>_<trial>``.

    Args:
        model_config: Parsed model YAML (reads the "diffusion" section).
        inference_config: Parsed inference YAML.
        device: Device the model and sampling run on.
        ckpt_iter: Checkpoint to load; a number or "max".
    """
    # A missing "trials" key means a single run (previously None, which crashed on `> 1`).
    trials = inference_config.get("trials", 1)
    batch_size = inference_config["batch_size"]

    local_path = MODEL_PATH_FORMAT.format(
        T=model_config["diffusion"]["T"],
        beta_0=model_config["diffusion"]["beta_0"],
        beta_T=model_config["diffusion"]["beta_T"],
    )

    diffusion_hyperparams = calc_diffusion_hyperparams(
        **model_config["diffusion"], device=device
    )
    LOGGER.info(display_current_time())
    net = setup_model(inference_config["use_model"], model_config, device)

    # Check if multiple GPUs are available
    if torch.cuda.device_count() > 0:
        net = nn.DataParallel(net)

    data_names = ["imputation", "original", "mask"]
    base_directory = inference_config["output_directory"]

    for trial in range(1, trials + 1):
        LOGGER.info(f"The {trial}th inference trial")
        # Trials are numbered from 1, so the first one saves everything. Later trials
        # pass a one-element list (not the bare string "imputation"), matching the
        # Iterable[str] that DiffusionGenerator's saved_data_names expects.
        saved_data_names = data_names if trial == 1 else data_names[:1]
        # Build the per-trial path directly; str.format on a user-supplied path
        # would break on any literal braces it contains.
        output_directory = f"{base_directory}_{trial}" if trials > 1 else base_directory

        DiffusionGenerator(
            net=net,
            device=device,
            diffusion_hyperparams=diffusion_hyperparams,
            data_path=inference_config["data"]["test_path"],
            local_path=local_path,
            output_directory=output_directory,
            ckpt_path=inference_config["ckpt_path"],
            ckpt_iter=ckpt_iter,
            batch_size=batch_size,
            masking=inference_config["masking"],
            missing_k=inference_config["missing_k"],
            only_generate_missing=inference_config["only_generate_missing"],
            saved_data_names=saved_data_names,
            enable_spatial_prediction=inference_config.get("enable_spatial_inference", True),
            known_location_path=os.path.abspath(inference_config["known_location_path"]),
            unknown_location_path=os.path.abspath(inference_config["unknown_location_path"]),
        ).generate()

        LOGGER.info("Inference complete")
        LOGGER.info(display_current_time())


if __name__ == "__main__":
    # Entry point: parse CLI flags, load both YAML configs, pick a device, run inference.
    args = fetch_args()

    with open(args.model_config, "rt") as model_file:
        model_config = yaml.safe_load(model_file.read())
    with open(args.inference_config, "rt") as inference_file:
        inference_config = yaml.safe_load(inference_file.read())

    gpu_count = torch.cuda.device_count()
    if gpu_count > 0:
        LOGGER.info(f"Using {gpu_count} GPUs!")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    run_job(model_config, inference_config, device, args.ckpt_iter)

generator.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import logging
import os
from typing import Dict, Iterable, Optional, Union

import numpy as np
import torch
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
from torch.utils.data import DataLoader
from tqdm import tqdm  # New

from sssd.core.model_specs import MASK_FN
from sssd.data.utils import get_dataloader  # New
from sssd.utils.logger import setup_logger
from sssd.utils.utils import find_max_epoch, sampling
from autoFRK import AutoFRK, to_tensor, garbage_cleaner  # New

# Module-level logger created by the project's setup_logger helper.
LOGGER = setup_logger()


class DiffusionGenerator:
    """
    Generate data based on ground truth.

    Every batch of the inference dataloader is imputed by the diffusion
    network via ``sampling``; when spatial prediction is enabled, the result
    is additionally extended to unobserved locations with AutoFRK.

    Args:
        net (torch.nn.Module): The neural network model.
        device (Optional[Union[torch.device, str]]): The device to run the model on (e.g., 'cuda' or 'cpu').
        diffusion_hyperparams (dict): Dictionary of diffusion hyperparameters.
        local_path (str): Local path format for the model.
        data_path (str): Path to the inference data, passed to ``get_dataloader``.
        output_directory (str): Path to save generated samples.
        batch_size (int): Batch size of the inference dataloader.
        ckpt_path (str): Checkpoint directory.
        ckpt_iter (str): Pretrained checkpoint to load; 'max' selects the maximum iteration.
        masking (str): Key into ``MASK_FN`` selecting the masking strategy
            (e.g. 'mnr', 'bm', 'rm', 'forecast').
        missing_k (int): Number of missing time points for each channel across the length.
        only_generate_missing (int): Whether to generate only missing portions of the signal:
                                      - 0 (all sample diffusion),
                                      - 1 (generate missing portions only).
        enable_spatial_prediction (bool): If True, run AutoFRK on the de-normalised
            SSSD output to predict values at the unknown locations.
        known_location_path (str): ``.npy`` file with the coordinates of the known locations.
        unknown_location_path (str): ``.npy`` file with the coordinates of the locations to predict.
        saved_data_names (Iterable[str], optional): Names of data arrays to save (default is ("imputation", "original", "mask")).
        logger (Optional[logging.Logger], optional): Logger object for logging messages (default is None).
    """

    def __init__(
        self,
        net: torch.nn.Module,
        device: Optional[Union[torch.device, str]],
        diffusion_hyperparams: dict,
        local_path: str,
        data_path: str,
        output_directory: str,
        batch_size: int,
        ckpt_path: str,
        ckpt_iter: str,
        masking: str,
        missing_k: int,
        only_generate_missing: int,
        enable_spatial_prediction: bool,
        known_location_path: str,
        unknown_location_path: str,
        saved_data_names: Iterable[str] = ("imputation", "original", "mask"),
        logger: Optional[logging.Logger] = None,
    ):
        self.net = net
        self.device = device
        self.diffusion_hyperparams = diffusion_hyperparams
        self.local_path = local_path
        loader, ts_mean, ts_std = get_dataloader(
            path=data_path,
            batch_size=batch_size,
            device=device,
            inference=True,
            missing_k=missing_k
        )
        self.dataloader = loader
        # Statistics used to undo the loader's normalisation before AutoFRK.
        self.ts_mean = ts_mean
        self.ts_std = ts_std
        self.batch_size = batch_size
        self.masking = masking
        self.missing_k = missing_k
        self.only_generate_missing = only_generate_missing
        self.enable_spatial_prediction = enable_spatial_prediction
        self.known_location_path = known_location_path
        self.unknown_location_path = unknown_location_path
        self.logger = logger or LOGGER

        self.output_directory = self._prepare_output_directory(
            output_directory, local_path, ckpt_iter
        )
        self.saved_data_names = saved_data_names
        self._load_checkpoint(ckpt_path, ckpt_iter)

    def _load_checkpoint(self, ckpt_path: str, ckpt_iter: str) -> None:
        """Load ``<ckpt_path>/<local_path>/<ckpt_iter>.pkl`` into ``self.net``.

        ``ckpt_iter == "max"`` is resolved to the latest iteration via
        ``find_max_epoch``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            Exception: If the file exists but cannot be loaded; the underlying
                error is chained as ``__cause__``.
        """
        ckpt_path = os.path.join(ckpt_path, self.local_path)
        if ckpt_iter == "max":
            ckpt_iter = find_max_epoch(ckpt_path)
        model_path = os.path.join(ckpt_path, f"{ckpt_iter}.pkl")
        try:
            checkpoint = torch.load(model_path, map_location="cpu")
            self.net.load_state_dict(checkpoint["model_state_dict"])
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Model file not found at {model_path}") from e
        except Exception as e:
            # Chain the cause so the real traceback is not lost.
            raise Exception(f"Failed to load model: {e}") from e
        self.logger.info(f"Successfully loaded model at iteration {ckpt_iter}")

    def _prepare_output_directory(
        self, output_directory: str, local_path: str, ckpt_iter: str
    ) -> str:
        """Create (if needed) and return the directory for generated samples.

        The path is ``<output_directory>/<local_path>/<tag>`` where ``tag`` is
        ``"max"`` or ``"imputation_multiple_<ckpt_iter // 1000>k"``.
        """
        ckpt_iter_str = (
            "max"
            if ckpt_iter == "max"
            else f"imputation_multiple_{int(ckpt_iter) // 1000}k"
        )
        output_directory = os.path.join(output_directory, local_path, ckpt_iter_str)
        os.makedirs(output_directory, exist_ok=True)
        os.chmod(output_directory, 0o775)
        self.logger.info(f"Output directory: {output_directory}")
        return output_directory

    def _update_mask(self, batch: torch.Tensor) -> torch.Tensor:
        """Build one mask from the first sample and repeat it for the whole batch."""
        transposed_mask = MASK_FN[self.masking](batch[0], self.missing_k)
        return (
            transposed_mask.permute(1, 0)
            .repeat(batch.size()[0], 1, 1)
            .to(self.device, dtype=torch.float32)
        )

    def _save_data(
        self,
        results: Union[Dict[str, np.ndarray], np.ndarray],
        index: int,
    ) -> None:
        """Save the selected arrays as ``<name><index>.npy`` in the output directory.

        Args:
            results: Mapping from data name (e.g. "imputation", "original",
                "mask") to array. Only names listed in ``self.saved_data_names``
                are written. A bare array is accepted and saved as "imputation".
            index: Suffix appended to each file name.
        """
        if isinstance(results, np.ndarray):
            results = {"imputation": results}
        for name, data in results.items():
            if name in self.saved_data_names:
                np.save(os.path.join(self.output_directory, f"{name}{index}.npy"), data)

    @staticmethod
    def _load_locations(
        path: str,
        dtype: torch.dtype,
        device: Optional[Union[torch.device, str]],
    ) -> torch.Tensor:
        """Load a coordinate array from ``path`` as a 2-D tensor (one row per location)."""
        loc = to_tensor(np.load(path), dtype=dtype, device=device)
        # 1-D coordinates (a single spatial dimension) become a column vector.
        return loc.view(-1, 1) if loc.ndim == 1 else loc

    def _autoFRK_generate(
        self,
        sssd_inference,
        with_known_loc: bool = True,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Predict every (time, variable) slice at the unknown locations with AutoFRK.

        Args:
            sssd_inference: SSSD output indexed as ``[location, time, variable]``
                over the known locations, in the loader's normalised units;
                converted with ``to_tensor``.
            with_known_loc: If True, prepend the de-normalised known-location
                series so the rows are ``known + unknown`` locations.
            dtype: Floating dtype for all AutoFRK computations. Defaults to
                float32: float16 keeps only ~3 significant digits when undoing
                the normalisation, and many ``torch.linalg`` kernels have no
                half-precision implementation.

        Returns:
            Tensor of shape ``(n_unknown, T, V)``, or ``(N + n_unknown, T, V)``
            when ``with_known_loc`` is True.
        """
        self.logger.info("Start autoFRK inference step")
        device = self.device
        sssd_inference = to_tensor(sssd_inference, dtype=dtype, device=device)
        ts_mean = to_tensor(self.ts_mean, dtype=dtype, device=device)
        ts_std = to_tensor(self.ts_std, dtype=dtype, device=device)
        # Undo the dataloader's normalisation so AutoFRK works in data units.
        sssd_inference = sssd_inference * ts_std + ts_mean

        _, T, V = sssd_inference.shape
        known_loc = self._load_locations(self.known_location_path, dtype, device)
        unknown_loc = self._load_locations(self.unknown_location_path, dtype, device)
        n_unknown = unknown_loc.shape[0]
        autoFRK_inference = torch.zeros((n_unknown, T, V), dtype=dtype, device=device)

        frk = AutoFRK(
            logger_level=30,
            dtype=dtype,
            device=device,
        )

        # The MRTS basis from the first successful fit is reused for every later
        # slice (the locations never change between slices).
        mrts = None
        with tqdm(total=V * T, desc="inferencing autoFRK") as pbar:
            for variable in range(V):
                for time_index in range(T):
                    data_slice = sssd_inference[:, time_index, variable]
                    try:
                        frk.forward(
                            data=data_slice,
                            loc=known_loc,
                            G=mrts,
                            method="fast",
                            tps_method="rectangular",
                            requires_grad=False,
                        )
                        pred = frk.predict(newloc=unknown_loc)["pred.value"]
                        if mrts is None:
                            mrts = frk.obj["G"]["MRTS"]
                    except torch.linalg.LinAlgError:
                        self.logger.warning(
                            f"Skipped variable={variable}, time={time_index} due to ill-conditioned matrix"
                        )
                        # Failed slices are left as zeros.
                        pred = torch.zeros(n_unknown, dtype=dtype, device=device)

                    autoFRK_inference[:, time_index, variable] = pred.reshape(-1)
                    pbar.update(1)

        if with_known_loc:
            autoFRK_inference = torch.cat([sssd_inference, autoFRK_inference], dim=0)
        return autoFRK_inference

    def generate(self) -> np.ndarray:
        """Impute every batch, optionally run AutoFRK, save results and return the imputation.

        Saves (subject to ``saved_data_names``) "imputation", "original" and
        "mask" with index 0. "original" and "mask" stay in the dataloader's
        units, whereas the AutoFRK imputation is de-normalised.

        Returns:
            The imputation array that was passed to ``_save_data``.

        Raises:
            ValueError: If the dataloader yields no batches.
        """
        generated, originals, masks = [], [], []
        for (batch,) in self.dataloader:
            batch = batch.to(self.device)
            mask = self._update_mask(batch)
            batch = batch.permute(0, 2, 1)

            generated_series = sampling(
                net=self.net,
                size=batch.shape,
                diffusion_hyperparams=self.diffusion_hyperparams,
                cond=batch,
                mask=mask,
                only_generate_missing=self.only_generate_missing,
                device=self.device,
            )
            generated.append(generated_series.detach().cpu().numpy())
            originals.append(batch.detach().cpu().numpy())
            masks.append(mask.detach().cpu().numpy())

        if not generated:
            raise ValueError("The inference dataloader produced no batches")

        # Undo the permute(0, 2, 1) applied before sampling so every array is
        # back in the dataloader's axis order (assumes mask matches batch layout).
        sssd_inference = np.concatenate(generated, axis=0).transpose(0, 2, 1)
        original = np.concatenate(originals, axis=0).transpose(0, 2, 1)
        mask_array = np.concatenate(masks, axis=0).transpose(0, 2, 1)

        if self.enable_spatial_prediction:
            imputation = self._autoFRK_generate(
                sssd_inference=sssd_inference,
                with_known_loc=True,
            ).detach().cpu().numpy()
        else:
            imputation = sssd_inference

        self._save_data(
            {"imputation": imputation, "original": original, "mask": mask_array},
            0,
        )
        return imputation

training.yaml

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Training configuration
batch_size: 64  # Batch size
output_directory: "./results/weather2k"  # Output directory for checkpoints and logs
ckpt_iter: "max"  # Checkpoint mode (max or min)
iters_per_ckpt: 100  # Checkpoint frequency (number of iterations)
iters_per_logging: 100  # Log frequency (number of iterations)
n_iters: 3800  # Maximum number of iterations
learning_rate: 0.0005  # Learning rate

# Additional training settings
only_generate_missing: true  # Generate missing values only
use_model: 2  # Model to use for training
masking: "forecast"  # Masking strategy for missing values
missing_k: 19  # Number of missing (masked) time points per channel

# Data paths
data:
  train_path: "./datasets/weather2k/train_sssd.npy"  # Path to training data

# autoFRK config
enable_spatial_prediction: true  # Enable spatial prediction step
autoFRK_period: 100  # Run the autoFRK spatial step every N training iterations
location_path: "./datasets/weather2k/stations_known_locations.npy"  # Path to known locations

inference.yaml

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
# Inference configuration
batch_size: 64  # Batch size for inference
output_directory: "./results/weather2k/inference"  # Output directory for inference results
ckpt_path: "./results/weather2k"  # Path to checkpoint for inference
trials: 1 # Replications

# Additional inference settings
only_generate_missing: true  # Generate missing values only
use_model: 2  # Model to use for inference
masking: "forecast"  # Masking strategy for missing values
missing_k: 19  # Number of missing (masked) time points per channel

# Data paths
data:
  test_path: "./datasets/weather2k/test_sssd.npy"  # Path to test data

# autoFRK config
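enable_spatial_inference: true  # Enable spatial prediction step (NOTE: DiffusionGenerator and training.yaml name this enable_spatial_prediction; confirm run_job reads this key)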
known_location_path: "./datasets/weather2k/stations_known_locations.npy"  # Path to known locations
unknown_location_path: "./datasets/weather2k/stations_unknown_locations.npy"  # Path to unknown locations

參考資料