While recently working with PyTorch `Dataset` and `DataLoader`, I wrote some code like this:
```python
import torch


class DatasetBase(torch.utils.data.Dataset):
    args = {}

    def __init__(self, args):
        self.args.update(args)

    def __getitem__(self, idx):
        return 0

    def __len__(self):
        return 5


class Dataset(DatasetBase):
    args = {'important_parameter': 1}

    def __init__(self, args):
        super(Dataset, self).__init__(args)
        print('__init__', self.args)

    def _worker_init_fn(self, *args):
        print('worker init', self.args)

    def get_dataloader(self):
        return torch.utils.data.DataLoader(self, batch_size=1,
                                           num_workers=3,
                                           worker_init_fn=self._worker_init_fn,
                                           pin_memory=True,
                                           multiprocessing_context='spawn')


def main():
    dataloader = Dataset({'important_parameter': 2}).get_dataloader()
    for _ in dataloader:
        pass


if __name__ == '__main__':
    main()
```
In the code above, I try to have a default setting for `args` in the `Dataset` class, while passing customized arguments when I actually use it. In this case, I want to pass `{'important_parameter': 2}` to the dataset to overwrite the default `{'important_parameter': 1}`. But unfortunately, the output of the code is:

```
__init__ {'important_parameter': 2}
worker init {'important_parameter': 1}
worker init {'important_parameter': 1}
worker init {'important_parameter': 1}
```
It turns out that the workers are not aware of the overwritten `self.args`. I also tried moving `self.args.update(args)` from `DatasetBase` into `Dataset`, but it didn't help. In the end, though, I somehow (I don't remember how) came up with this:
```python
class DatasetBase(torch.utils.data.Dataset):
    args = {}

    def __init__(self, args):
        self.args.update(args)
        self.args = self.args

    def __getitem__(self, idx):
        return 0

    def __len__(self):
        return 5
```
Basically I added `self.args = self.args` to the `__init__` function of `DatasetBase`. And surprisingly, the output became:

```
__init__ {'important_parameter': 2}
worker init {'important_parameter': 2}
worker init {'important_parameter': 2}
worker init {'important_parameter': 2}
```
Nailed it! I was confused for a second, but soon realized that it might have something to do with class attributes in Python (what other languages would call static variables). Before the line `self.args = self.args`, `self.args` still resolves to a class attribute, which means you can access it from outside the class through the class name directly, e.g. `DatasetBase.args`. But this magic line actually changes the situation: the assignment creates an instance attribute that shadows the class attribute. Furthermore, it also has something to do with `multiprocessing_context='spawn'`, since I don't need the magic line of `self.args = self.args` to make the code work if I use other multiprocessing contexts.
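This class-attribute hunch can be poked at without torch at all. The sketch below is my own minimal reconstruction (the `Toy` class and `magic` flag are made up for illustration, not part of the original code): `self.args.update(...)` mutates the class-level dict, the magic line `self.args = self.args` turns it into an instance attribute living in `__dict__`, and only instance attributes survive the pickle round trip that `spawn` uses to hand the dataset to fresh worker processes (which re-import the module and thus re-create the class attribute with its default value):

```python
import pickle


class Toy:
    args = {'important_parameter': 1}  # class attribute, re-created on every fresh import

    def __init__(self, new_args, magic=False):
        self.args.update(new_args)  # resolves to the *class* dict and mutates it in place
        if magic:
            self.args = self.args   # creates an *instance* attribute shadowing the class one


# With the magic line: args lands in the instance __dict__, so pickle carries it along.
ds = Toy({'important_parameter': 2}, magic=True)
print(ds.__dict__)                  # {'args': {'important_parameter': 2}}
payload = pickle.dumps(ds)

# A 'spawn' worker starts a fresh interpreter and re-imports the module,
# re-creating the class attribute with its default; mimic that by resetting it.
Toy.args = {'important_parameter': 1}
worker_ds = pickle.loads(payload)
print(worker_ds.args)               # {'important_parameter': 2} -- restored from __dict__

# Without the magic line: __dict__ stays empty, nothing is pickled,
# and the "worker" falls back to the freshly re-created class attribute.
plain = Toy({'important_parameter': 2})
print(plain.__dict__)               # {}
payload = pickle.dumps(plain)
Toy.args = {'important_parameter': 1}
worker_plain = pickle.loads(payload)
print(worker_plain.args)            # {'important_parameter': 1}
```

If this is the right reading, the other contexts work without the magic line because a forked worker inherits a copy of the parent's memory, mutated class attribute included, instead of re-importing the module.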
It is good to know that this kind of thing exists, but I don't have an in-depth explanation for this behavior. If you know why, please let me know in the comments!