目錄

1140304 meeting

$SSSD^{S4}$ 模型輸入與輸出探討

封面圖片是 $SSSD^{S4}$ 模型的運作過程,2024 年 2 月 28 日擷取自 Github 。

$SSSD$ 模型的資料集

我嘗試使用原作者於 Github 上所提供的範例,使用 MuJoCo 資料集進行訓練。由原作者的 Github ,我們可以知道這是從 NRTSI 中所提取的資料集。 NRTSI 提供一系列適用於訓練時間序列模型的資料集,除了提供打包成 .npy 格式的訓練集、測試集檔案外,也提供其他用於填補時間序列的方法。

上週 meeting 中的 $SSSD^{S4}$ 輸入資料 train_mujoco.npy ,我們可以得到

  1. 此資料是 NumPy 的陣列類型,並且是 $8000 \times 100 \times 14$ 的三維陣列。
  2. 依據 config 設定,所填補的結果皆為 $500 \times 14 \times 100$ 的陣列。

以下為用於測試的 Python 腳本與執行結果。

載入模組、設定函數與參數。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# import modele
import numpy as np
import os
import pickle

# function
def check_dataset(file_path):
    data = np.load(file_path, allow_pickle=True)
    print(f'Type: {type(data)}, Shape: {data.shape}')
    print(f'{data}\n')
    return data

# main program
## parameter
output_dir_path = r'D:\Code\sssd_cp_learning_and_testing\learning_and_testing\SSSD\results\mujoco\90\T200_beta00.0001_betaT0.02'

查看 Mujoco 資料集的訓練集

1
2
3
4
## model input dataset (mujoco)
print(f'train_mujoco.npy')
file_path = r"D:\Code\sssd_cp_learning_and_testing\learning_and_testing\SSSD\datasets\Mujoco\train_mujoco.npy"
train_mujoco = check_dataset(file_path)
執行結果參考
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
train_mujoco.npy
Type: <class 'numpy.ndarray'>, Shape: (8000, 100, 14)
[[[1.0188564  0.34472078 0.9193448  ... 1.5551598  0.8262319  1.449314  ]
  [1.0205188  0.34052694 0.92114544 ... 1.5594999  0.82539403 1.450077  ]
  [1.0221945  0.33620164 0.92290837 ... 1.563845   0.82457316 1.4503349 ]
  ...
  [1.0996897  0.09380768 0.91228646 ... 2.0254884  0.74934673 1.4205385 ]
  [1.0996888  0.093808   0.9123445  ... 2.0256474  0.7493937  1.421181  ]
  [1.0996878  0.09380717 0.91238856 ... 2.0259614  0.749325   1.4219254 ]]

 [[1.1299545  0.56555724 0.8105636  ... 2.0017648  0.7691426  1.3971107 ]
  [1.1328042  0.5636149  0.8118937  ... 2.001997   0.76909196 1.3971739 ]
  [1.1356547  0.5615865  0.81322247 ... 2.00223    0.7690404  1.397237  ]
  ...
  [1.3818084  0.04634438 0.8966854  ... 1.9546272  0.7980857  1.5649099 ]
  [1.3814888  0.04611954 0.8954035  ... 1.9563212  0.7971047  1.5648654 ]
  [1.3811697  0.04583227 0.89410746 ... 1.9577047  0.79626673 1.5648133 ]]

 [[0.9395856  0.26847756 0.8532757  ... 2.0925462  0.88624686 1.4330561 ]
  [0.9394652  0.26569197 0.85121363 ... 2.0933154  0.8870823  1.4280789 ]
  [0.9393492  0.26279953 0.84915906 ... 2.0925698  0.8863424  1.4268135 ]
  ...
  [0.93903416 0.06014196 0.7692848  ... 2.0294168  0.7476863  1.4105136 ]
  [0.93900484 0.06018732 0.769219   ... 2.0295722  0.7476749  1.4098489 ]
  [0.93897474 0.0602343  0.7691506  ... 2.0297213  0.74766856 1.4091637 ]]

 ...

 [[0.9482162  0.46353704 0.81380767 ... 1.0249602  1.2738136  1.350296  ]
  [0.94537544 0.46859425 0.81512964 ... 1.0194359  1.275639   1.3530833 ]
  [0.9427458  0.47382125 0.81637526 ... 1.0178125  1.2746325  1.3559105 ]
  ...
  [0.70087945 0.6286735  0.8703839  ... 2.1573842  0.6628122  1.4355185 ]
  [0.6981438  0.6258822  0.870551   ... 2.1580424  0.6628611  1.4355744 ]
  [0.6954119  0.6230053  0.87070715 ... 2.1587005  0.6629189  1.4356297 ]]

 [[1.084146   0.13920696 0.8108991  ... 2.304378   0.7502878  1.465541  ]
  [1.0859725  0.13577522 0.8078264  ... 2.4035823  0.7357796  1.4806526 ]
  [1.0877669  0.13278663 0.80474424 ... 2.4574542  0.7256363  1.4895606 ]
  ...
  [1.1298585  0.15936916 0.84470177 ... 2.0286796  0.7642296  1.4276175 ]
  [1.1294307  0.1585352  0.8442135  ... 2.029027   0.7654774  1.4275758 ]
  [1.1290065  0.15761633 0.8437144  ... 2.0292292  0.76694435 1.4277219 ]]

 [[0.8497951  0.628934   1.0146345  ... 2.0245514  0.81360376 1.431919  ]
  [0.84885335 0.6278094  1.0138252  ... 2.0257835  0.8132482  1.4319915 ]
  [0.84791064 0.6265908  1.0130211  ... 2.0270183  0.81288934 1.4320644 ]
  ...
  [0.7533675  0.12601541 1.0557469  ... 2.0797296  0.7346148  1.4350158 ]
  [0.752361   0.11651292 1.0562434  ... 2.0798457  0.73452085 1.4350255 ]
  [0.7513538  0.10692114 1.0567361  ... 2.0799599  0.73442733 1.4350348 ]]]

查看模型填補結果 imputation0.npy

1
2
3
4
5
## model output
### imputation
print(f'imputation0.npy')
file_path = os.path.join(output_dir_path, 'imputation0.npy')
imputation0 = check_dataset(file_path)
執行結果參考
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
imputation0.npy
Type: <class 'numpy.ndarray'>, Shape: (500, 14, 100)
[[[0.9779899  0.98437977 1.0914965  ... 1.1645157  0.92187357 0.9408777 ]
  [0.45340696 0.45619234 0.45239526 ... 0.09200554 0.09504005 0.09845223]
  [1.0418788  1.0411524  1.0791966  ... 1.0726292  1.0911459  1.0768353 ]
  ...
  [1.9333378  1.9531232  1.9822652  ... 1.9063805  1.9313103  1.9919088 ]
  [0.7332712  0.7334014  0.72775054 ... 0.79629755 0.74248964 0.77277774]
  [1.3981225  1.3949718  1.4293809  ... 1.3827648  1.4142646  1.4201723 ]]

 [[0.8017445  0.8029894  0.7908626  ... 0.715918   0.6784949  0.691612  ]
  [0.28352863 0.27537826 0.2731     ... 0.05175451 0.03815852 0.04471585]
  [1.1245177  1.1291656  1.1952096  ... 1.182712   1.1503464  1.1865829 ]
  ...
  [1.8880249  1.8927475  1.8664936  ... 1.9198602  1.8935788  1.971485  ]
  [0.66902924 0.66292894 0.67104393 ... 0.74909174 0.7419653  0.7354708 ]
  [1.3671505  1.3676192  1.4229809  ... 1.448961   1.4355994  1.447378  ]]

 [[1.035299   1.0503275  1.1367137  ... 1.2797318  1.0733584  1.0659832 ]
  [0.41497543 0.4130687  0.42461908 ... 0.08333386 0.05842083 0.0736234 ]
  [1.0304348  1.0206361  1.0440897  ... 1.0806819  1.0630673  1.0686685 ]
  ...
  [1.8276113  1.8258247  1.8390502  ... 1.9030184  1.9482071  1.9963055 ]
  [0.72919995 0.71213514 0.72257394 ... 0.7469197  0.7098038  0.71767974]
  [1.374692   1.3816208  1.4401623  ... 1.3709145  1.3681242  1.3561548 ]]

 ...

 [[1.0017495  1.0141724  1.030208   ... 1.1663914  1.0523969  1.0220131 ]
  [0.3814863  0.39106598 0.3751795  ... 0.03352571 0.04184113 0.0321909 ]
  [0.82935625 0.8351018  0.84719926 ... 0.78193384 0.79715204 0.8116017 ]
  ...
  [1.9513766  1.9874988  1.9917512  ... 1.6337692  1.6832944  1.7768489 ]
  [0.9911371  0.86839604 0.8214125  ... 0.7012732  0.62618077 0.6505313 ]
  [1.3318325  1.3209419  1.3354611  ... 1.4256169  1.4097162  1.4321613 ]]

 [[0.8313655  0.8345677  0.8353071  ... 0.7665064  0.69832975 0.70574576]
  [0.15491015 0.1849071  0.19601335 ... 0.04784474 0.05161434 0.04703766]
  [0.78100544 0.7520876  0.7652805  ... 0.7345611  0.7311005  0.7375943 ]
  ...
  [1.9338595  1.9964237  2.0456364  ... 1.8495271  1.8366884  1.9121118 ]
  [0.76967365 0.7433044  0.72210073 ... 0.7822682  0.7486641  0.7456464 ]
  [1.3081641  1.3280656  1.3773645  ... 1.4195722  1.4250567  1.4304402 ]]

 [[0.8578036  0.8595259  0.8822513  ... 0.81210816 0.75448877 0.75732774]
  [0.26452947 0.24751675 0.24756058 ... 0.07032993 0.0835114  0.08039934]
  [1.0795563  1.0977736  1.1368525  ... 1.1014563  1.1016102  1.0979248 ]
  ...
  [1.5444437  1.5485339  1.5500091  ... 1.9859809  2.0041018  2.034277  ]
  [0.76173866 0.74596816 0.74051833 ... 0.7301484  0.71769965 0.7053322 ]
  [1.3084735  1.3045095  1.2793151  ... 1.4134452  1.411825   1.424677  ]]]

查看模型遮罩 mask0.npy

1
2
3
print(f'mask0.npy')
file_path = os.path.join(output_dir_path, 'mask0.npy')
mask0 = check_dataset(file_path)
執行結果參考
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
mask0.npy
Type: <class 'numpy.ndarray'>, Shape: (500, 14, 100)
[[[0. 0. 1. ... 1. 0. 0.]
  [0. 0. 0. ... 0. 0. 1.]
  [0. 0. 1. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 1.]
  [1. 0. 0. ... 1. 0. 0.]
  [0. 0. 1. ... 0. 0. 0.]]

 [[0. 0. 1. ... 1. 0. 0.]
  [0. 0. 0. ... 0. 0. 1.]
  [0. 0. 1. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 1.]
  [1. 0. 0. ... 1. 0. 0.]
  [0. 0. 1. ... 0. 0. 0.]]

 [[0. 0. 1. ... 1. 0. 0.]
  [0. 0. 0. ... 0. 0. 1.]
  [0. 0. 1. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 1.]
  [1. 0. 0. ... 1. 0. 0.]
  [0. 0. 1. ... 0. 0. 0.]]

 ...

 [[0. 0. 1. ... 1. 0. 0.]
  [0. 0. 0. ... 0. 0. 1.]
  [0. 0. 1. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 1.]
  [1. 0. 0. ... 1. 0. 0.]
  [0. 0. 1. ... 0. 0. 0.]]

 [[0. 0. 1. ... 1. 0. 0.]
  [0. 0. 0. ... 0. 0. 1.]
  [0. 0. 1. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 1.]
  [1. 0. 0. ... 1. 0. 0.]
  [0. 0. 1. ... 0. 0. 0.]]

 [[0. 0. 1. ... 1. 0. 0.]
  [0. 0. 0. ... 0. 0. 1.]
  [0. 0. 1. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 1.]
  [1. 0. 0. ... 1. 0. 0.]
  [0. 0. 1. ... 0. 0. 0.]]]

查看 original0.npy 檔案

1
2
3
print(f'original0.npy')
file_path = os.path.join(output_dir_path, 'original0.npy')
original0 = check_dataset(file_path)
執行結果參考
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
original0.npy
Type: <class 'numpy.ndarray'>, Shape: (500, 14, 100)
[[[1.1403443  1.1419119  1.1434577  ... 1.2344646  1.2355565  1.2366652 ]
  [0.44503313 0.44048095 0.43586108 ... 0.09895922 0.09892198 0.09881818]
  [1.1080282  1.1087638  1.1094663  ... 1.2303919  1.2330028  1.2356437 ]
  ...
  [2.0913386  2.0889924  2.0888383  ... 1.9993011  1.9984908  1.9976755 ]
  [0.73612225 0.7470203  0.7525134  ... 0.8020001  0.80041975 0.7988277 ]
  [1.4327983  1.4327874  1.4329102  ... 1.4458942  1.4462585  1.4466181 ]]

 [[0.7846492  0.78152096 0.7783957  ... 0.7394441  0.73993295 0.7404261 ]
  [0.18229088 0.17586255 0.16937983 ... 0.04270957 0.04123447 0.03970251]
  [1.2447149  1.245681   1.2466406  ... 1.2939184  1.2957157  1.2975751 ]
  ...
  [2.0702024  2.0702455  2.0702934  ... 1.9835824  1.9814789  1.9794396 ]
  [0.66228145 0.66188604 0.6615124  ... 0.74620795 0.745842   0.74548316]
  [1.4399137  1.4400566  1.4401985  ... 1.4306774  1.4308321  1.430989  ]]

 [[1.1801645  1.1829025  1.185635   ... 1.353138   1.3533303  1.3534759 ]
  [0.42506483 0.42107233 0.41697702 ... 0.08018772 0.08011684 0.08016345]
  [1.0592788  1.060556   1.061817   ... 1.1660019  1.16614    1.1663134 ]
  ...
  [2.0391326  2.0405219  2.0418937  ... 2.0199184  2.0142357  2.01137   ]
  [0.7274884  0.7268261  0.72616607 ... 0.7510811  0.7506006  0.75041914]
  [1.4546845  1.4547614  1.4548372  ... 1.4172906  1.4174985  1.4176484 ]]

 ...

 [[1.041035   1.0441163  1.0472333  ... 1.2278599  1.2281811  1.2284214 ]
  [0.34285313 0.33905002 0.3351447  ... 0.02366458 0.02423722 0.02502784]
  [0.872797   0.87108004 0.86942434 ... 0.8618627  0.86235774 0.8627792 ]
  ...
  [2.2161305  2.2242827  2.2318084  ... 1.7459784  1.7671947  1.7800756 ]
  [1.0148592  1.0114378  1.0082384  ... 0.71554226 0.71559244 0.71584934]
  [1.314138   1.314857   1.3155673  ... 1.4610794  1.4615589  1.4618688 ]]

 [[0.8292507  0.8287961  0.8283144  ... 0.817418   0.8179071  0.8183963 ]
  [0.03687635 0.03981134 0.04268485 ... 0.05158203 0.05035937 0.0491191 ]
  [0.79548436 0.7929163  0.790848   ... 0.7809997  0.7824856  0.7840376 ]
  ...
  [2.2418637  2.240296   2.1343472  ... 1.9140354  1.9133472  1.9098523 ]
  [0.7737858  0.77333015 0.78193784 ... 0.7994397  0.7997823  0.79999524]
  [1.3916144  1.3917903  1.3927475  ... 1.4365767  1.4376484  1.4388508 ]]

 [[0.891903   0.8903191  0.8887091  ... 0.86425775 0.86428493 0.8643113 ]
  [0.1616086  0.15369551 0.1457126  ... 0.08562256 0.085611   0.08560599]
  [1.1634592  1.1706057  1.1778103  ... 1.2142228  1.2142922  1.2143536 ]
  ...
  [1.6939791  1.6989081  1.7037703  ... 2.0460596  2.0468404  2.0475001 ]
  [0.77711374 0.7763163  0.77555615 ... 0.73519385 0.7354764  0.73573035]
  [1.2510566  1.2510723  1.251081   ... 1.4290721  1.4289652  1.4288527 ]]]

比對測試集

測試集矩陣資訊如下。

1
2
3
print(f'test_mujoco.npy')
file_path = r"D:\Code\sssd_cp_learning_and_testing\learning_and_testing\SSSD\datasets\Mujoco\test_mujoco.npy"
train_mujoco = check_dataset(file_path)
執行結果參考
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
test_mujoco.npy
Type: <class 'numpy.ndarray'>, Shape: (2000, 100, 14)
[[[1.1403443  0.44503313 1.1080282  ... 2.0913386  0.73612225 1.4327983 ]
  [1.1419119  0.44048095 1.1087638  ... 2.0889924  0.7470203  1.4327874 ]
  [1.1434577  0.43586108 1.1094663  ... 2.0888383  0.7525134  1.4329102 ]
  ...
  [1.2344646  0.09895922 1.2303919  ... 1.9993011  0.8020001  1.4458942 ]
  [1.2355565  0.09892198 1.2330028  ... 1.9984908  0.80041975 1.4462585 ]
  [1.2366652  0.09881818 1.2356437  ... 1.9976755  0.7988277  1.4466181 ]]

 [[0.7846492  0.18229088 1.2447149  ... 2.0702024  0.66228145 1.4399137 ]
  [0.78152096 0.17586255 1.245681   ... 2.0702455  0.66188604 1.4400566 ]
  [0.7783957  0.16937983 1.2466406  ... 2.0702934  0.6615124  1.4401985 ]
  ...
  [0.7394441  0.04270957 1.2939184  ... 1.9835824  0.74620795 1.4306774 ]
  [0.73993295 0.04123447 1.2957157  ... 1.9814789  0.745842   1.4308321 ]
  [0.7404261  0.03970251 1.2975751  ... 1.9794396  0.74548316 1.430989  ]]

 [[1.1801645  0.42506483 1.0592788  ... 2.0391326  0.7274884  1.4546845 ]
  [1.1829025  0.42107233 1.060556   ... 2.0405219  0.7268261  1.4547614 ]
  [1.185635   0.41697702 1.061817   ... 2.0418937  0.72616607 1.4548372 ]
  ...
  [1.353138   0.08018772 1.1660019  ... 2.0199184  0.7510811  1.4172906 ]
  [1.3533303  0.08011684 1.16614    ... 2.0142357  0.7506006  1.4174985 ]
  [1.3534759  0.08016345 1.1663134  ... 2.01137    0.75041914 1.4176484 ]]

 ...

 [[1.0392356  0.08581756 0.7099395  ... 2.0302138  0.79023373 1.3759874 ]
  [1.0393484  0.08554491 0.7099324  ... 2.0082736  0.7879161  1.3767979 ]
  [1.0394112  0.08544233 0.70994276 ... 1.9963381  0.7867339  1.377317  ]
  ...
  [1.0310189  0.08667419 0.6916089  ... 2.0218747  0.75082636 1.4282196 ]
  [1.030978   0.08667409 0.6915245  ... 2.0217166  0.75112975 1.4281985 ]
  [1.0309373  0.08667286 0.69144064 ... 2.0216012  0.7514324  1.4281791 ]]

 [[1.0522602  0.14923434 0.8988266  ... 2.4332573  0.7175308  1.4357938 ]
  [1.0528201  0.14891274 0.8963647  ... 2.4315226  0.71667176 1.4355559 ]
  [1.0533336  0.14862238 0.89461184 ... 2.2289753  0.72654146 1.4373304 ]
  ...
  [1.0292536  0.06743535 0.77494836 ... 2.0265324  0.74734867 1.4283221 ]
  [1.0285449  0.06419316 0.77352923 ... 2.0248575  0.74751395 1.4282281 ]
  [1.0278436  0.06086201 0.77209336 ... 2.023194   0.7475705  1.4281812 ]]

 [[0.8054271  0.38287213 0.7195578  ... 2.0518327  0.6804158  1.4449072 ]
  [0.8015366  0.38034502 0.71702343 ... 2.0528839  0.6800664  1.444928  ]
  [0.7976455  0.3777312  0.7144991  ... 2.053944   0.6797198  1.4449484 ]
  ...
  [0.4756081  0.11156489 0.5117482  ... 2.3870192  0.80451727 1.3886195 ]
  [0.47334316 0.11220536 0.50762016 ... 2.4114656  0.8274635  1.3821098 ]
  [0.47113514 0.11276194 0.50347656 ... 2.4220142  0.8363771  1.379714  ]]]

繪製預測圖

由 NRTSI 的論文中,我們可以得知 MuJoCo 的資料集包含 14 個維度的特徵值, 100 個時間點與 2,000 次實驗。

https://raw.githubusercontent.com/Josh-test-lab/website-assets-repository/refs/heads/main/posts/1140304%20meeting/mujoco%20dataset%20description.jpg
在 NRTSI 的論文中,對 MuJoCo 資料集的描述。

因此,我們可以將上述資料繪製如下。

https://raw.githubusercontent.com/Josh-test-lab/website-assets-repository/refs/heads/main/posts/1140304%20meeting/Figure_1.png

上圖繪製了第 10 次實驗的第 3 個特徵值,取前 100 筆資料。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def plot_experiment_feature(test_mujoco, original0, imputation0, n, f):
    """
    繪製第 n 次實驗的第 f 個特徵值隨時間變化,包含 test_mujoco、original0 和 imputation0。

    :param test_mujoco: (2000, 100, 14) 的數據集
    :param original0: (500, 14, 100) 的數據集,需要轉置為 (500, 100, 14)
    :param imputation0: (500, 14, 100) 的數據集,需要轉置為 (500, 100, 14)
    :param n: 指定實驗次數
    :param f: 指定特徵索引
    """

    # 轉置 original0 和 imputation0,使形狀變為 (500, 100, 14)
    original0 = original0.transpose(0, 2, 1)
    imputation0 = imputation0.transpose(0, 2, 1)

    if min(original0.shape[1], imputation0.shape[1]) < n:
        n = min(original0.shape[1], imputation0.shape[1])

    if min(original0.shape[2], imputation0.shape[2]) < f:
        f = min(original0.shape[2], imputation0.shape[2])

    time_steps = np.arange(100)

    plt.figure(figsize=(12, 6))
    plt.plot(time_steps, test_mujoco[n, :, f], label="Test Mujoco", alpha=0.7, linestyle='-')
    plt.plot(time_steps, original0[n, :, f], label="Original0", alpha=0.7, linestyle='dotted')
    plt.plot(time_steps, imputation0[n, :, f], label="Imputation0", alpha=0.7, linestyle='dashed')

    plt.xlabel("Time Step")
    plt.ylabel(f"Feature {f} Value")
    plt.title(f"Experiment {n} - Feature {f} Over Time")
    plt.legend()
    plt.show()

n = 10  # 第 10 次實驗
f = 3   # 第 3 個特徵
plot_experiment_feature(test_mujoco, original0, imputation0, n, f)

運行環境

  • 作業系統:Windows 11 24H2
  • 程式語言:Python 3.12.9

延伸學習

參考資料