
Use Argument linking to link init_args to dict_kwargs. #375

Open
bot66 opened this issue Sep 11, 2023 · 8 comments
Labels
enhancement New feature or request

Comments

bot66 commented Sep 11, 2023

🚀 Feature request

Use Argument linking to link init_args to dict_kwargs

Motivation

I am trying to link data.input_width to model.dict_kwargs.input_width. Transformer models need the input image shape for initialization, but most CNNs don't. I have a factory function that creates both CNNs and Transformers and uses kwargs to handle both types, so I want to link the dataset's image shape information to the model's initialization kwargs.

When I use the code below

from lightning.pytorch.cli import LightningCLI


class CommandLineInterface(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("data.input_width", "model.dict_kwargs.input_width")
        parser.link_arguments("data.input_height", "model.dict_kwargs.input_height")

It raises an error:

│                                                                                                  │
│   142 │   │   valid_target_leaf = self.target[1].dest == target                                  │
│   143 │   │   if not valid_target_leaf and is_target_subclass and not valid_target_init_arg:     │
│   144 │   │   │   prefix = self.target[1].dest+'.init_args.'                                     │
│ ❱ 145 │   │   │   raise ValueError(f'Target key expected to start with "{prefix}", got "{targe   │
│   146 │   │                                                                                      │
│   147 │   │   # Replace target action with link action                                           │
│   148 │   │   if not is_target_subclass or valid_target_leaf:                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Target key expected to start with "model.init_args.", got "model.dict_kwargs.input_width
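The check that fires can be read off the traceback. Below is a simplified, hypothetical re-creation of lines 142 to 145 (not jsonargparse's actual code). It assumes that model is a subclass-type argument and that the `valid_target_init_arg` flag, whose computation the traceback does not show, means "the target starts with the prefix". It illustrates why any key under model.dict_kwargs is rejected while model.init_args.* keys pass.

```python
def check_link_target(target: str, subclass_dest: str) -> None:
    """Hypothetical re-creation of the link-target check shown in the traceback.

    subclass_dest is the destination of a subclass-type argument, e.g. "model".
    """
    prefix = subclass_dest + ".init_args."
    valid_target_leaf = target == subclass_dest
    # Assumption: the traceback does not show how this flag is computed.
    valid_target_init_arg = target.startswith(prefix)
    if not valid_target_leaf and not valid_target_init_arg:
        raise ValueError(f'Target key expected to start with "{prefix}", got "{target}"')


check_link_target("model.init_args.input_width", "model")  # accepted
try:
    check_link_target("model.dict_kwargs.input_width", "model")  # rejected
except ValueError as err:
    print(err)
```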

Pitch

It would be great if we could link arbitrary init_args to dict_kwargs. Or is this feature already implemented?

Alternatives

bot66 added the enhancement (New feature or request) label on Sep 11, 2023
mauvilsa (Member)

Does your model have init parameters input_width and input_height? If so, then the links should target init_args, as described in the error message, i.e.:

        parser.link_arguments("data.input_width", "model.init_args.input_width")
        parser.link_arguments("data.input_height", "model.init_args.input_height")

Any particular reason why you are using dict_kwargs instead of init_args? dict_kwargs is mostly for parameters that fail to resolve.


bot66 commented Sep 11, 2023


Not all of the models I use have the init parameters input_width and input_height; only the Vision Transformers do.

I use dict_kwargs instead of init_args because I wrote a classification model wrapper based on lightning.pytorch.LightningModule, and its __init__() method calls a factory function that creates various backbones. Different backbones have different init_args, and some of them (e.g. SwinTransformer) have many arguments, so I use (*args, **kwargs) instead of defining them all explicitly in __init__(). I use LightningCLI to manage configs, so I was wondering whether I could link some arguments between the LightningModule and the LightningDataModule to get a simpler command line (as in the example above).

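import lightning.pytorch as pl
from torch import nn

# create_loss and ImageAugmentation are defined elsewhere (not shown in this snippet).


class LitClassification(pl.LightningModule):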
    """This is a PyTorch lightning wrapper for Image Classification tasks."""
    def __init__(
        self,
        num_classes: int,
        model_name: str = "resnet18",
        loss_name: str = "cross_entropy_loss",
        in_chans: int = 3,
        pretrained: bool = False,
        init_lr: float = 1e-3,
        weight_decay: float = 1e-2,
        augmentations: list[str] = [],
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        self.model = create_model(model_name, in_chans, num_classes, pretrained, *args, **kwargs)
        self.loss_fn = create_loss(loss_name, *args, **kwargs)
        self.lr = init_lr
        self.weight_decay = weight_decay
        self.num_classes = num_classes
        self.augmentation = ImageAugmentation(augmentations)

        self.save_hyperparameters()


def create_model(
    model_name: str, in_chans: int, num_classes: int, pretrained: bool = False, *args, **kwargs
) -> nn.Module:
    """Factory function for creating models.

    Args:
        model_name (str): Model name.
        in_chans (int): Number of model input channels.
        num_classes (int): Number of model classes.
        pretrained (bool, optional): If `True`, load pretrained weights. Defaults to `False`.
        *args, **kwargs: Extra model-specific arguments forwarded to the model constructor.

    Returns:
        nn.Module: Created model.
    """
    return eval(model_name)(in_chans, num_classes, pretrained=pretrained, *args, **kwargs)

bot66 changed the title "Use Argument linking to link init_args to dict_kwargs." to "Use Argument linking to link init_args to dict_kwargs." on Sep 11, 2023
mauvilsa (Member)

The recommended way of implementing submodules (e.g. backbone, loss) is via dependency injection, see models-with-multiple-submodules.

Supporting dict_kwargs as a link target is not a trivial change and is unlikely to be added any time soon, in part because the motivation comes from a less recommended pattern.

Do you think it is viable to change your code to use dependency injection?
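A framework-free sketch of what dependency injection could look like for this wrapper. ResNetBackbone, ViTBackbone, and the stripped-down LitClassification below are hypothetical stand-ins (no torch or Lightning required). The point is that each backbone declares its own typed parameters, so ViT-only settings such as input_width live on the ViT class instead of being passed through **kwargs.

```python
class ResNetBackbone:
    """Hypothetical CNN backbone: needs no input image shape."""

    def __init__(self, in_chans: int = 3, depth: int = 18) -> None:
        self.in_chans = in_chans
        self.depth = depth


class ViTBackbone:
    """Hypothetical transformer backbone: needs the input image shape."""

    def __init__(self, in_chans: int = 3, input_width: int = 224, input_height: int = 224) -> None:
        self.in_chans = in_chans
        self.input_width = input_width
        self.input_height = input_height


class LitClassification:
    """Stand-in for the LightningModule: the backbone is built outside and injected."""

    def __init__(self, backbone: object, num_classes: int, init_lr: float = 1e-3) -> None:
        self.backbone = backbone  # no factory or **kwargs forwarding needed
        self.num_classes = num_classes
        self.lr = init_lr


cnn_model = LitClassification(ResNetBackbone(depth=50), num_classes=10)
vit_model = LitClassification(ViTBackbone(input_width=384, input_height=384), num_classes=10)
```

In the real LightningModule, the backbone parameter would presumably be annotated with nn.Module (or a shared base class); per the linked documentation, that type hint is what lets the CLI offer class selection such as --model.backbone={class_name}.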


bot66 commented Sep 11, 2023

I understand. However, dependency injection doesn't solve my problem: I want a short, unified command line, and dependency injection would make the config.yaml and the training command line more complicated.

Thanks for your advice anyway.

mauvilsa (Member)

You could add input_width and input_height as parameters of LitClassification, then inspect the signature of eval(model_name) and drop input_width and input_height from the kwargs if that model does not accept them.
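A minimal sketch of this suggestion using the standard inspect module. TinyCNN, TinyViT, and filter_kwargs are hypothetical names; in the real code the class would come from eval(model_name). Once input_width and input_height are explicit parameters of the wrapper, they can be link targets under model.init_args, and the factory only forwards them to models whose constructor accepts them.

```python
import inspect
from typing import Any, Callable

_KEYWORD_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


def filter_kwargs(fn: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return only the kwargs that fn accepts as keyword arguments."""
    params = inspect.signature(fn).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs)  # fn takes **kwargs itself, so pass everything through
    accepted = {p.name for p in params if p.kind in _KEYWORD_KINDS}
    return {k: v for k, v in kwargs.items() if k in accepted}


class TinyCNN:  # hypothetical model without input shape parameters
    def __init__(self, in_chans: int, num_classes: int, pretrained: bool = False) -> None:
        self.in_chans, self.num_classes = in_chans, num_classes


class TinyViT:  # hypothetical model that needs the input shape
    def __init__(self, in_chans: int, num_classes: int, pretrained: bool = False,
                 input_width: int = 224, input_height: int = 224) -> None:
        self.input_width, self.input_height = input_width, input_height


def create_model(model_cls: type, in_chans: int, num_classes: int,
                 pretrained: bool = False, **kwargs: Any) -> Any:
    # Drop keys (e.g. input_width) that this particular model does not accept.
    return model_cls(in_chans, num_classes, pretrained=pretrained,
                     **filter_kwargs(model_cls, kwargs))


shape = {"input_width": 384, "input_height": 384}  # e.g. values linked from data.*
cnn = create_model(TinyCNN, 3, 10, **shape)  # shape keys are dropped
vit = create_model(TinyViT, 3, 10, **shape)  # shape keys are forwarded
```

Filtering against the signature keeps the single factory while still letting the CLI see input_width and input_height as ordinary, linkable init arguments.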

mauvilsa (Member)

dependency injection would make the config.yaml and training command line more complicated

Why exactly do you say this? With dependency injection, a class can be selected from the command line like --model.backbone={class_name}. In a config file it can also be selected with just a string, like:

model:
  backbone: {class_name}

The automatically saved config would indeed be more complex, because the submodule gets expanded into a nested class_path and init_args with all settings. But running the CLI is just as simple as what you are doing now.
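To make the "expanded into a nested class_path and init_args" point concrete, here is a rough, hypothetical sketch of turning the short string form into the full form by filling in constructor defaults. It is not how jsonargparse implements this; the backbone classes, REGISTRY, and expand_class_spec are assumptions for illustration only.

```python
import inspect
from typing import Any, Dict, Union


class ResNetBackbone:  # hypothetical
    def __init__(self, in_chans: int = 3, depth: int = 18) -> None: ...


class ViTBackbone:  # hypothetical
    def __init__(self, in_chans: int = 3, input_width: int = 224, input_height: int = 224) -> None: ...


# Hypothetical lookup table from short class names to classes.
REGISTRY = {cls.__name__: cls for cls in (ResNetBackbone, ViTBackbone)}


def expand_class_spec(spec: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Expand "ViTBackbone" into {"class_path": ..., "init_args": {...all settings...}}."""
    if isinstance(spec, str):
        spec = {"class_path": spec}
    cls = REGISTRY[spec["class_path"].rsplit(".", 1)[-1]]
    defaults = {
        name: p.default
        for name, p in inspect.signature(cls).parameters.items()
        if p.default is not inspect.Parameter.empty
    }
    return {
        "class_path": f"{cls.__module__}.{cls.__qualname__}",
        # User-supplied init_args override the constructor defaults.
        "init_args": {**defaults, **spec.get("init_args", {})},
    }


short = "ViTBackbone"  # what the user writes
full = expand_class_spec(short)  # roughly the shape of what a saved config would contain
```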


bot66 commented Sep 12, 2023

I'm building a model training API server. I want the request body (the training config) to be as simple as possible because our users have no ML experience. The request body structure is fixed, and some extra arguments go under the dict_kwargs field. I parse the request body and then build the training command line to start a training process.

So I thought linking data.input_width to model.dict_kwargs.input_width would not break the request structure, and users would not need to add these values to the dict_kwargs field themselves. I guess dependency injection would make the "parse request to command line" step more complicated, but I will try it.
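One way to keep the "parse request to command line" step small, even with nested backbone settings, is to flatten the request body generically into dotted --key=value flags. This is a hypothetical sketch: request_to_argv, the request layout, and the "fit" subcommand default are assumptions, and the resulting keys must match how your LightningCLI lays out its arguments (for example --model.backbone={class_name}, as mentioned earlier in the thread).

```python
import json
from typing import Any


def request_to_argv(request: dict[str, Any], subcommand: str = "fit") -> list[str]:
    """Flatten a nested request body into dotted --key=value CLI flags."""
    argv = [subcommand]

    def walk(prefix: str, value: Any) -> None:
        if isinstance(value, dict):
            for key, child in value.items():
                walk(f"{prefix}.{key}" if prefix else key, child)
        else:
            # Strings pass through; numbers, bools and lists are JSON-encoded,
            # so True becomes "true" and lists become JSON arrays.
            text = value if isinstance(value, str) else json.dumps(value)
            argv.append(f"--{prefix}={text}")

    walk("", request)
    return argv


request = {
    "model": {"num_classes": 10, "pretrained": True, "backbone": "ViTBackbone"},
    "data": {"input_width": 384},
}
argv = request_to_argv(request)
```

Because the flattening is generic, adding a nested {"class_path": ..., "init_args": {...}} backbone to the request changes no parsing code; only the validation of allowed values grows.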


bot66 commented Sep 12, 2023

Using dependency injection sounds great; I just need to write more data validation code for my API 🤔.
