
VisionDataModule set/get transform doesn't change dataset transform #1064

@jascase901

Description


🐛 Bug

Setting the transform on the data module after setup() should change the transform of the underlying dataset.

from pl_bolts.datamodules import MNISTDataModule
from torchvision import transforms as transform_lib

mnist = MNISTDataModule(data_dir="/tmp/mnist")
mnist.prepare_data()
mnist.setup(stage="fit")

print("before set_transform")
print(mnist.dataset_train.dataset.transforms)

# Expect this to change the train dataset transform?
mnist.train_transforms = transform_lib.Compose(
    [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.6,), std=(0.5,))]
)

# Expect this to print the new transform
print("after transform")
print(mnist.dataset_train.dataset.transforms)

Results

before set_transform
StandardTransform
Transform: Compose(
               ToTensor()
           )
after transform
StandardTransform
Transform: Compose(
               ToTensor()
           )
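One plausible explanation, assuming the data module builds its datasets inside setup() and hands the current transform to the dataset constructor at that moment (an assumption about pl_bolts internals, not confirmed in this report), is that assigning mnist.train_transforms afterwards only rebinds an attribute on the module; the dataset object already created keeps the transform it was given. The sketch below reproduces that pattern with hypothetical ToyDataset and ToyDataModule classes standing in for torchvision and pl_bolts, and shows that calling setup() again picks up the new transform.

```python
from typing import Callable, List, Optional


class ToyDataset:
    """Hypothetical stand-in for a torchvision dataset: transform is stored at construction."""

    def __init__(self, data: List[int], transform: Optional[Callable[[int], int]] = None):
        self.data = data
        self.transform = transform

    def __getitem__(self, idx: int) -> int:
        x = self.data[idx]
        return self.transform(x) if self.transform is not None else x

    def __len__(self) -> int:
        return len(self.data)


class ToyDataModule:
    """Hypothetical data module that reads train_transforms only inside setup()."""

    def __init__(self, train_transforms: Optional[Callable[[int], int]] = None):
        self.train_transforms = train_transforms
        self.dataset_train: Optional[ToyDataset] = None

    def setup(self) -> None:
        # The current value of self.train_transforms is passed to the dataset here;
        # reassigning the attribute later does not touch this dataset object.
        self.dataset_train = ToyDataset([1, 2, 3], transform=self.train_transforms)


dm = ToyDataModule(train_transforms=lambda x: x * 10)
dm.setup()
print(dm.dataset_train[0])  # 10

dm.train_transforms = lambda x: x + 100  # only rebinds the attribute on the module
print(dm.dataset_train[0])  # still 10: the dataset kept the old transform

dm.setup()  # rebuilding the datasets picks up the new transform
print(dm.dataset_train[0])  # 101
```

Under this assumption, a workaround with the real MNISTDataModule would be to assign train_transforms before calling setup(), or to call setup(stage="fit") again after changing it.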

Expected

I expected mnist.dataset_train.dataset.transforms to show the new Compose (ToTensor plus Normalize) after assigning mnist.train_transforms.
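One possible way to get the expected behaviour, sketched here as an assumption rather than as how pl_bolts implements or should implement it, is to make train_transforms a property whose setter also pushes the new value into the dataset that setup() already built. The hypothetical ToySubset mirrors the dataset_train.dataset path used in the repro, where the training set is wrapped in a subset-like object; PropagatingDataModule and the other names are illustrative only.

```python
from typing import Callable, List, Optional


class ToyDataset:
    """Hypothetical stand-in for a torchvision dataset holding a transform."""

    def __init__(self, data: List[int], transform: Optional[Callable[[int], int]] = None):
        self.data = data
        self.transform = transform

    def __getitem__(self, idx: int) -> int:
        x = self.data[idx]
        return self.transform(x) if self.transform is not None else x


class ToySubset:
    """Hypothetical stand-in for a subset wrapper exposing the full dataset as `.dataset`."""

    def __init__(self, dataset: ToyDataset, indices: List[int]):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx: int) -> int:
        return self.dataset[self.indices[idx]]


class PropagatingDataModule:
    """Hypothetical data module whose train_transforms setter updates already-built datasets."""

    def __init__(self, train_transforms: Optional[Callable[[int], int]] = None):
        self._train_transforms = train_transforms
        self.dataset_train: Optional[ToySubset] = None

    @property
    def train_transforms(self) -> Optional[Callable[[int], int]]:
        return self._train_transforms

    @train_transforms.setter
    def train_transforms(self, value: Optional[Callable[[int], int]]) -> None:
        self._train_transforms = value
        # If setup() has already run, propagate the new transform to the wrapped dataset.
        if self.dataset_train is not None:
            self.dataset_train.dataset.transform = value

    def setup(self) -> None:
        full = ToyDataset([1, 2, 3, 4], transform=self._train_transforms)
        self.dataset_train = ToySubset(full, indices=[0, 1, 2])


dm = PropagatingDataModule(train_transforms=lambda x: x * 10)
dm.setup()
print(dm.dataset_train[1])  # 20
dm.train_transforms = lambda x: x + 100  # setter updates dataset_train.dataset.transform
print(dm.dataset_train[1])  # 102
```

The toy dataset has a single transform attribute. torchvision's VisionDataset also builds a separate transforms wrapper (the StandardTransform printed in the report) in its constructor, so a real fix along these lines would have to keep both attributes in sync, or simply rebuild the datasets.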

Environment

  • PyTorch Version (e.g., 1.0): 1.13.1+cu117
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.10
  • CUDA/cuDNN version: 11

Metadata

  • Assignees: No one assigned
  • Labels: bug (Something isn't working), help wanted (Extra attention is needed)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
