class DiffusionTrainer:
    ...

    def _train_per_epoch(self) -> torch.Tensor:
        """Run one pass over ``self.dataloader``, stepping the optimizer per batch.

        For every batch yielded by ``self.dataloader`` (each item is unpacked
        as a 1-tuple), the batch is moved to ``self.device``, a mask is
        obtained from ``self._update_mask``, the batch's last two axes are
        swapped, and ``training_loss`` (with an MSE criterion) is minimized
        with one ``self.optimizer`` step.

        Returns:
            The loss of the final batch, detached from the autograd graph so
            the graph is not kept alive after the epoch finishes.

        Raises:
            ValueError: If the permuted batch, the mask and the loss mask do
                not all have the same shape.
            RuntimeError: If ``self.dataloader`` yields no batches, so there
                is no loss to return.
        """
        # MSELoss has no learnable state, so one instance serves the epoch.
        loss_function = nn.MSELoss()
        last_loss = None

        for (batch,) in tqdm(self.dataloader):
            batch = batch.to(self.device)
            mask = self._update_mask(batch)
            # loss_mask is the logical complement of mask; presumably it marks
            # the positions training_loss scores — confirm against training_loss.
            loss_mask = ~mask.bool()
            # Swap the last two axes of the 3-D batch. The shape check below
            # requires the mask from _update_mask to already be in this layout.
            batch = batch.permute(0, 2, 1)
            if not (batch.size() == mask.size() == loss_mask.size()):
                raise ValueError(
                    f"shape mismatch: batch {tuple(batch.size())}, "
                    f"mask {tuple(mask.size())}, "
                    f"loss_mask {tuple(loss_mask.size())}"
                )

            self.optimizer.zero_grad()
            loss = training_loss(
                model=self.net,
                loss_function=loss_function,
                training_data=(batch, batch, mask, loss_mask),
                diffusion_parameters=self.diffusion_hyperparams,
                generate_only_missing=self.only_generate_missing,
                device=self.device,
            )
            loss.backward()
            self.optimizer.step()
            # Keep only the value; dropping the graph frees its memory.
            last_loss = loss.detach()

            # TODO: autoFRK step — compute an autoFRK loss for this batch and
            # backpropagate/step on it once that component exists.

        if last_loss is None:
            raise RuntimeError(
                "dataloader yielded no batches; cannot compute an epoch loss"
            )
        return last_loss