While working with PyTorch's Dataset and DataLoader recently, I wrote some code like this:

import torch


class DatasetBase(torch.utils.data.Dataset):
    args = {}
    def __init__(self, args):
        self.args.update(args)

    def __getitem__(self, idx):
        return 0

    def __len__(self):
        return 5


class Dataset(DatasetBase):
    args = {'important_parameter': 1}
    def __init__(self, args):
        super(Dataset, self).__init__(args)
        print('__init__', self.args)

    def _worker_init_fn(self, *args):
        print('worker init', self.args)

    def get_dataloader(self):
        return torch.utils.data.DataLoader(self, batch_size=1,
                                         num_workers=3,
                                         worker_init_fn=self._worker_init_fn,
                                         pin_memory=True,
                                         multiprocessing_context='spawn')


def main():
    dataloader = Dataset({'important_parameter': 2}).get_dataloader()
    for _ in dataloader:
        pass


if __name__ == '__main__':
    main()

In the code above, I try to give the Dataset class a default setting for args while passing customized arguments when I actually use it. In this case, I want to pass {'important_parameter': 2} to the dataset to overwrite the default {'important_parameter': 1}. Unfortunately, the output of the code is:

__init__ {'important_parameter': 2}
worker init {'important_parameter': 1}
worker init {'important_parameter': 1}
worker init {'important_parameter': 1}

It turns out that the workers are not aware of the overwritten self.args. I also tried moving self.args.update(args) from DatasetBase into Dataset, and it didn’t help. In the end, though, I somehow (I don’t remember how) came up with this:

class DatasetBase(torch.utils.data.Dataset):
    args = {}
    def __init__(self, args):
        self.args.update(args)
        self.args = self.args  # rebinds the class-level dict as an instance attribute

    def __getitem__(self, idx):
        return 0

    def __len__(self):
        return 5

Basically, I added self.args = self.args in the __init__ function of DatasetBase. And surprisingly, the output became:

__init__ {'important_parameter': 2}
worker init {'important_parameter': 2}
worker init {'important_parameter': 2}
worker init {'important_parameter': 2}

Nailed it! I was confused for a second but soon realized that it might have something to do with static variables (class attributes, in Python terms). Before the line self.args = self.args runs, self.args is still a static variable, which means you can access it outside the class directly through the class name (for example Dataset.args). But this magic line actually changes that. Furthermore, it also has something to do with multiprocessing_context='spawn', since I don’t need the magic line self.args = self.args to make the code work if I use other multiprocessing contexts.
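
One likely explanation, sketched here without any multiprocessing: self.args.update(args) mutates the dict stored on the class, so the instance itself holds nothing of its own. With 'spawn', the dataset is pickled and sent to freshly started worker processes, and pickling an object stores its instance __dict__ by default, not class attributes; the worker re-imports the module and sees the class default again. With 'fork', the worker starts as a copy of the parent's memory, so the mutated class dict comes along. The names Base and Child and the rebind flag below are hypothetical stand-ins for DatasetBase and Dataset, and the snippet only shows where the data lives, not the DataLoader itself.

```python
class Base:
    args = {}  # class attribute, shared through the class

    def __init__(self, args, rebind=False):
        # Mutates the dict found via attribute lookup (here: Child.args).
        self.args.update(args)
        if rebind:
            # Same dict object, but now also stored in the instance __dict__.
            self.args = self.args


class Child(Base):
    args = {'important_parameter': 1}


plain = Child({'important_parameter': 2})
print(vars(plain))   # {} : nothing stored on the instance
print(Child.args)    # {'important_parameter': 2} : the class dict was mutated

rebound = Child({'important_parameter': 3}, rebind=True)
print(vars(rebound))               # {'args': {'important_parameter': 3}}
print(rebound.args is Child.args)  # True : one dict, two references
```

Because rebound carries args in its own __dict__, the value would travel with the pickled object. Note that the instance and the class still share one dict here; building a fresh dict per instance (for example {**type(self).args, **args}) would avoid mutating the class default at all, though that is a design suggestion rather than what the original code does.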

It is good to know that this kind of thing exists, but I don’t have an in-depth explanation for this behavior. If you know why, please let me know in the comments!