diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index b7079941..e12253e5 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -185,7 +185,7 @@ class ExtensibleTrainer(BaseModel): sort_key = opt_get(self.opt, ['train', 'sort_key'], None) if sort_key is not None: - sort_indices = torch.sort(data[sort_key]).indices + sort_indices = torch.sort(data[sort_key], descending=True).indices else: sort_indices = None