Shortcuts

torch.nn.utils.stateless.functional_call

torch.nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs=None, *, tie_weights=True, strict=False)[source]

Performs a functional call on the module by replacing the module parameters and buffers with the provided ones.

Warning

This API is deprecated as of PyTorch 2.0 and will be removed in a future version of PyTorch. Please use torch.func.functional_call() instead, which is a drop-in replacement for this API.

Note

If the module has active parametrizations, passing a value in the parameters_and_buffers argument with the name set to the regular parameter name will completely disable the parametrization. If you want to apply the parametrization function to the value passed please set the key as {submodule_name}.parametrizations.{parameter_name}.original.

Note

If the module performs in-place operations on parameters/buffers, these will be reflected in the parameters_and_buffers input.

Example:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # does self.foo = self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)

Note

If the module has tied weights, whether or not functional_call respects the tying is determined by the tie_weights flag.

Example:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
>>> new_a = {'foo', torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
Parameters:
  • module (torch.nn.Module) – the module to call

  • parameters_and_buffers (dict of str and Tensor) – the parameters that will be used in the module call.

  • args (Any or tuple) – arguments to be passed to the module call. If not a tuple, considered a single argument.

  • kwargs (dict) – keyword arguments to be passed to the module call

  • tie_weights (bool, optional) – If True, then parameters and buffers tied in the original model will be treated as tied in the reparamaterized version. Therefore, if True and different values are passed for the tied parameters and buffers, it will error. If False, it will not respect the originally tied parameters and buffers unless the values passed for both weights are the same. Default: True.

  • strict (bool, optional) – If True, then the parameters and buffers passed in must match the parameters and buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will error. Default: False.

Returns:

the result of calling module.

Return type:

Any

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources