
Tensor Parallelism - torch.distributed.tensor.parallel

Tensor Parallelism (TP) is built on top of PyTorch DistributedTensor (DTensor) and provides several parallelism styles: Rowwise, Colwise and Pairwise Parallelism.

Warning

Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your nn.Module using Tensor Parallelism is:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)[source]

The API to apply Tensor Parallelism (TP) in PyTorch. The module or its submodules are parallelized according to a parallelize_plan, which contains the ParallelStyle indicating how the user wants the module or submodule to be parallelized.

Users can also specify a different parallel style per module fully qualified name (FQN). The API supports 2D parallelism natively by accepting an n-dimensional device_mesh; users only need to specify the mesh dimension along which tensor parallelism is performed.

Parameters:
  • module (nn.Module) – Module to be parallelized.

  • device_mesh (DeviceMesh) – Object which describes the mesh topology of devices for the DTensor.

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]) – The plan used to parallelize the module. It can be either a ParallelStyle object, which describes how input/output are prepared for Tensor Parallelism, or a dict mapping module FQNs to their corresponding ParallelStyle objects.

  • tp_mesh_dim (int) – The dimension of device_mesh along which Tensor Parallelism is performed. Default: 0

Returns:

The parallelized nn.Module object.

Return type:

Module
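To make the role of the mesh dimension concrete, the sketch below models a device mesh as a plain tensor of ranks (an assumption for illustration; it does not use the DeviceMesh class). The hypothetical helper groups_along lists the rank groups that communicate along one mesh dimension: on a 2-by-4 mesh with tp_mesh_dim=1, TP runs within each row, while the other dimension (for example, data parallelism in a 2D setup) runs within each column.

```python
from typing import List

import torch


def groups_along(mesh: torch.Tensor, dim: int) -> List[List[int]]:
    """Return the rank groups that vary along `dim` of a rank mesh (hypothetical helper).

    Each group holds the ranks whose coordinates agree on every mesh dimension
    except `dim`; these are the ranks that communicate for that dimension.
    """
    # Move `dim` to the end, then every row of the flattened view is one group.
    return mesh.movedim(dim, -1).reshape(-1, mesh.size(dim)).tolist()


# 8 ranks arranged as 2 (data-parallel) x 4 (tensor-parallel).
mesh = torch.arange(8).reshape(2, 4)
tp_groups = groups_along(mesh, dim=1)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
dp_groups = groups_along(mesh, dim=0)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With tp_mesh_dim=1, each TP group is a row of the mesh, so TP collectives never cross the data-parallel dimension.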

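Example:
>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed._tensor import DeviceMesh
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>>
>>> # Define the module.
>>> m = Model(...)
>>> # Build a 1-D mesh over all ranks (assumes the default process group is initialized).
>>> mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
>>> m = parallelize_module(m, mesh, PairwiseParallel())
>>>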

Warning

PairwiseParallel currently comes with constraints. If you need finer granularity, pass in a dict mapping module FQNs to parallel styles instead.
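As a minimal sketch of the dict form, the hypothetical ToyMLP below pairs ColwiseParallel on its first linear layer with RowwiseParallel on its second, keyed by submodule FQN. The main() function assumes one process per GPU launched with torchrun, and its DeviceMesh import path reflects PyTorch 2.0; both are assumptions that may need adjusting for your setup and release.

```python
import os

import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    """Hypothetical two-layer MLP used only for illustration."""

    def __init__(self, dim: int = 16, hidden: int = 64):
        super().__init__()
        self.net1 = nn.Linear(dim, hidden)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net2(self.relu(self.net1(x)))


def build_tp_plan() -> dict:
    # Keys are submodule FQNs (as reported by named_modules()); values are ParallelStyles.
    return {"net1": ColwiseParallel(), "net2": RowwiseParallel()}


def main() -> nn.Module:
    # Assumption: launched with torchrun, one process per GPU.
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh  # PyTorch 2.0 import path (assumed)

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
    return parallelize_module(ToyMLP().cuda(), mesh, build_tp_plan())
```

Calling main() from each torchrun worker would return the parallelized model. The plan keys must match names produced by model.named_modules(); a typo in an FQN means that submodule is not parallelized as intended.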

Tensor Parallelism supports the following parallel styles:

class torch.distributed.tensor.parallel.style.RowwiseParallel[source]

Partition the rows of a module. The input is assumed to be a sharded DTensor and the output a replicated DTensor.
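The arithmetic behind RowwiseParallel can be checked on a single process. The sketch below uses Megatron-style notation y = x @ W with W of shape (in_features, out_features) and simulates the ranks with list entries; both are assumptions for illustration. Since nn.Linear stores the transpose, this row split corresponds to splitting nn.Linear.weight along dim 1. Each rank multiplies its slice of the input's last dimension by its block of rows, and summing the partial results (an all-reduce in practice) yields the replicated output.

```python
from typing import List

import torch


def rowwise_linear(x_shards: List[torch.Tensor], w: torch.Tensor) -> torch.Tensor:
    """Simulate a row-sharded matmul y = x @ w across len(x_shards) ranks."""
    world_size = len(x_shards)
    # Each rank holds one contiguous block of rows of w.
    w_shards = w.chunk(world_size, dim=0)
    # Each rank computes a partial result from its input shard and weight shard.
    partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
    # Summing the partials stands in for an all-reduce; every rank gets the full y.
    return torch.stack(partials).sum(dim=0)


torch.manual_seed(0)
x = torch.randn(3, 8, dtype=torch.float64)
w = torch.randn(8, 5, dtype=torch.float64)
y = rowwise_linear(list(x.chunk(2, dim=-1)), w)
assert torch.allclose(y, x @ w)
```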

class torch.distributed.tensor.parallel.style.ColwiseParallel[source]

Partition the columns of a tensor or module. The input is assumed to be a replicated DTensor and the output a sharded DTensor.
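ColwiseParallel is the mirror image. In the single-process sketch below (same assumptions: Megatron-style y = x @ W with W of shape (in_features, out_features), ranks simulated by list entries), W is split by columns, which corresponds to splitting nn.Linear.weight along dim 0. Every rank sees the same replicated input and produces one slice of the output's last dimension.

```python
import torch


def colwise_linear(x: torch.Tensor, w: torch.Tensor, world_size: int) -> list:
    """Simulate a column-sharded matmul y = x @ w; returns one output shard per rank."""
    # Each rank holds one contiguous block of columns of w.
    w_shards = w.chunk(world_size, dim=1)
    # Every rank sees the same (replicated) x and produces a slice of y's last dim.
    return [x @ ws for ws in w_shards]


torch.manual_seed(0)
x = torch.randn(3, 8, dtype=torch.float64)
w = torch.randn(8, 6, dtype=torch.float64)
shards = colwise_linear(x, w, world_size=2)
# Concatenating the shards (an all-gather in practice) recovers the full output.
assert torch.allclose(torch.cat(shards, dim=-1), x @ w)
```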

class torch.distributed.tensor.parallel.style.PairwiseParallel(_prepare_input=None, _prepare_output=None)[source]

PairwiseParallel combines the colwise and rowwise styles as a fixed pair, similar to what Megatron-LM (https://arxiv.org/abs/1909.08053) does. Both the input and output are assumed to be replicated DTensors.
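The two styles pair well because a column-sharded first layer produces exactly the input layout a row-sharded second layer expects, and an elementwise activation in between can be applied to each shard independently. The sketch below simulates a two-layer MLP relu(x @ A) @ B on a single process; the Megatron-style notation and the ReLU activation are assumptions for illustration. Only one reduction, at the very end, is needed.

```python
import torch


def pairwise_mlp(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor, world_size: int) -> torch.Tensor:
    """Simulate relu(x @ a) @ b with a column-sharded `a` and a row-sharded `b`."""
    a_shards = a.chunk(world_size, dim=1)  # colwise: split the hidden dim of a
    b_shards = b.chunk(world_size, dim=0)  # rowwise: split the same hidden dim of b
    # No communication between the layers: ReLU is elementwise, so it acts per shard.
    partials = [torch.relu(x @ a_i) @ b_i for a_i, b_i in zip(a_shards, b_shards)]
    # A single all-reduce (simulated by a sum) produces the replicated output.
    return torch.stack(partials).sum(dim=0)


torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float64)
a = torch.randn(8, 16, dtype=torch.float64)
b = torch.randn(16, 8, dtype=torch.float64)
assert torch.allclose(pairwise_mlp(x, a, b, world_size=4), torch.relu(x @ a) @ b)
```

This is also why the constraint to an even number of MLP layers arises: each colwise layer needs a rowwise partner to close the pair.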

Warning

PairwiseParallel currently supports only nn.MultiheadAttention, nn.Transformer, or MLPs with an even number of layers.

Warning

Sequence Parallelism is still experimental and has not yet been evaluated.

class torch.distributed.tensor.parallel.style.PairwiseSequenceParallel[source]

PairwiseSequenceParallel combines the colwise and rowwise styles as a fixed pair together with sequence parallelism, similar to what Megatron-LM sequence parallelism (https://arxiv.org/pdf/2205.05198.pdf) does. Both the input and output are assumed to be sharded DTensors.
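Compared with PairwiseParallel, the sequence-parallel variant keeps activations outside the pair sharded along the sequence dimension: an all-gather replaces the replicated input, and a reduce-scatter replaces the final all-reduce. The single-process sketch below assumes a (sequence, hidden) layout with the sequence on dim 0, Megatron-style notation y = x @ W, and a ReLU activation. It illustrates the communication pattern, not the library's implementation.

```python
from typing import List

import torch


def pairwise_sequence_parallel_mlp(
    x_shards: List[torch.Tensor], a: torch.Tensor, b: torch.Tensor
) -> List[torch.Tensor]:
    """Simulate relu(x @ a) @ b where x arrives sharded on the sequence dim (dim 0)."""
    world_size = len(x_shards)
    # All-gather along the sequence dim: every rank now sees the full sequence.
    x = torch.cat(x_shards, dim=0)
    a_shards = a.chunk(world_size, dim=1)  # colwise split of the first layer
    b_shards = b.chunk(world_size, dim=0)  # rowwise split of the second layer
    partials = [torch.relu(x @ a_i) @ b_i for a_i, b_i in zip(a_shards, b_shards)]
    # Reduce-scatter: sum the partials, then each rank keeps its own sequence slice.
    return list(torch.stack(partials).sum(dim=0).chunk(world_size, dim=0))


torch.manual_seed(0)
x = torch.randn(6, 8, dtype=torch.float64)  # (sequence, hidden)
a = torch.randn(8, 16, dtype=torch.float64)
b = torch.randn(16, 8, dtype=torch.float64)
out_shards = pairwise_sequence_parallel_mlp(list(x.chunk(2, dim=0)), a, b)
assert torch.allclose(torch.cat(out_shards, dim=0), torch.relu(x @ a) @ b)
```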

Warning

PairwiseSequenceParallel currently supports only nn.MultiheadAttention, nn.Transformer, or MLPs with an even number of layers.

Since Tensor Parallelism is built on top of DTensor, the input and output placements of the module must be specified with DTensors so that the module interacts as expected with the modules before and after it. The following functions are used for input/output preparation:

torch.distributed.tensor.parallel.style.make_input_replicate_1d(input, device_mesh=None)[source]

Replicate the input tensor over a 1-D device mesh. This function is used in ParallelStyle.

Parameters:
  • input (Union[torch.Tensor, DTensor]) – This input tensor will be replicated over the 1-D DeviceMesh.

  • device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be replicated. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If DeviceMesh is not 1-D, an exception will be thrown. Default: None

Returns:

A DTensor replicated over device_mesh.

Return type:

DTensor

torch.distributed.tensor.parallel.style.make_input_reshard_replicate(input, device_mesh)[source]

Construct a DTensor sharded on dimension 0 from the tensors on each rank, then convert it to a replicated DTensor.

Parameters:
  • input (torch.Tensor) – The local tensor on each rank. Together, these local tensors form a global DTensor sharded on dimension 0 over the 1-D DeviceMesh, which is then converted to a replicated DTensor.

  • device_mesh (DeviceMesh) – The 1-D device mesh where input will be sharded. If DeviceMesh is not 1-D, an exception will be thrown.

Returns:

A DTensor sharded on dimension 0 over device_mesh and then converted to a replicated DTensor.

Return type:

DTensor
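In collective-communication terms, building a dim-0-sharded DTensor from per-rank tensors and converting it to a replicated one amounts to an all-gather along dimension 0. The sketch below simulates that on one process, with a list standing in for the ranks (an assumption for illustration); the helper reshard_replicate is hypothetical and not part of the library.

```python
from typing import List

import torch


def reshard_replicate(local_tensors: List[torch.Tensor]) -> List[torch.Tensor]:
    """Simulate an all-gather on dim 0: every rank ends up with the full tensor."""
    # The global (sharded) tensor is the dim-0 concatenation of the local pieces.
    global_tensor = torch.cat(local_tensors, dim=0)
    # Replicating means each rank holds its own copy of the whole global tensor.
    return [global_tensor.clone() for _ in local_tensors]


locals_per_rank = [torch.full((2, 3), float(rank)) for rank in range(4)]
replicas = reshard_replicate(locals_per_rank)
# Each of the 4 ranks now holds an (8, 3) tensor whose rows 2r:2r+2 came from rank r.
```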

torch.distributed.tensor.parallel.style.make_input_shard_1d(input, device_mesh=None, dim=0)[source]

Shard the input tensor on dim over a 1-D device mesh. This function is used in ParallelStyle.

Parameters:
  • input (Union[torch.Tensor, DTensor]) – The single tensor to be sharded on dimension dim over the 1-D DeviceMesh.

  • device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be sharded. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If DeviceMesh is not 1-D, an exception will be thrown. Default: None

  • dim (int, optional) – The sharding dimension of input tensor. Default: 0

Returns:

A DTensor sharded on dimension dim over device_mesh.

Return type:

DTensor

torch.distributed.tensor.parallel.style.make_input_shard_1d_last_dim(input, device_mesh=None)[source]

Wrapper around make_input_shard_1d with dim=-1.

Parameters:
  • input (Union[torch.Tensor, DTensor]) – This single tensor will be sharded on the last dimension over the 1-D DeviceMesh.

  • device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be sharded. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If DeviceMesh is not 1-D, an exception will be thrown. Default: None

Returns:

A DTensor sharded on the last dimension over device_mesh.

Return type:

DTensor
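As a rough model of what sharding on dim means for the local tensors, the sketch below splits a tensor into one contiguous chunk per rank with torch.chunk, and derives the last-dimension variant with functools.partial, mirroring how make_input_shard_1d_last_dim wraps make_input_shard_1d. The helper names are hypothetical, and the sketch assumes the size along dim divides evenly by the number of ranks (uneven splits are deliberately not modeled).

```python
import functools
from typing import List

import torch


def shard_1d(t: torch.Tensor, world_size: int, dim: int = 0) -> List[torch.Tensor]:
    """Return the local tensor each rank would hold when `t` is sharded on `dim`."""
    if t.size(dim) % world_size != 0:
        raise ValueError("this sketch only models evenly divisible shards")
    return list(t.chunk(world_size, dim=dim))


# Same wrapper pattern as make_input_shard_1d_last_dim: fix dim=-1.
shard_1d_last_dim = functools.partial(shard_1d, dim=-1)

t = torch.arange(24).reshape(4, 6)
row_shards = shard_1d(t, world_size=2)           # two (2, 6) local tensors
col_shards = shard_1d_last_dim(t, world_size=3)  # three (4, 2) local tensors
```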

torch.distributed.tensor.parallel.style.make_output_replicate_1d(output, device_mesh=None)[source]

Convert the output DTensor to a replicated DTensor. This function is used in ParallelStyle.

Parameters:
  • output (DTensor) – Output of module to be converted.

  • device_mesh (DeviceMesh, optional) – The 1-D device mesh used to replicate the output. An exception is thrown if a non-1-D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None

Returns:

A replicated DTensor object.

Return type:

DTensor

torch.distributed.tensor.parallel.style.make_output_reshard_tensor(output, device_mesh=None)[source]

Convert the output DTensor to a sharded DTensor and return the local tensor.

Parameters:
  • output (DTensor) – Output of module to be converted.

  • device_mesh (DeviceMesh, optional) – The 1-D device mesh used to shard the output. An exception is thrown if a non-1-D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None

Returns:

A torch.Tensor object converted from output DTensor.

Return type:

Tensor

torch.distributed.tensor.parallel.style.make_output_shard_1d(output, device_mesh=None, dim=0)[source]

Convert the output DTensor to a sharded DTensor. This function is used in ParallelStyle.

Parameters:
  • output (DTensor) – Output of module to be converted.

  • device_mesh (DeviceMesh, optional) – The 1-D device mesh used to shard the output. An exception is thrown if a non-1-D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None

  • dim (int) – Sharding dim for output. Default: 0

Returns:

A DTensor object sharded on the given dim.

Return type:

DTensor

torch.distributed.tensor.parallel.style.make_output_tensor(output, device_mesh=None)[source]

Convert the output DTensor to a replicated DTensor first, then convert it to a Tensor.

Parameters:
  • output (DTensor) – Output of module to be converted.

  • device_mesh (DeviceMesh, optional) – The 1-D device mesh used to replicate the output. An exception is thrown if a non-1-D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None

Returns:

A torch.Tensor object converted from output DTensor.

Return type:

Tensor

Currently, some constraints make it hard for the nn.MultiheadAttention module to work out of the box with Tensor Parallelism, so we built a custom multihead attention module for Tensor Parallelism users. In addition, parallelize_module automatically swaps nn.MultiheadAttention for this custom module when PairwiseParallel is specified.

class torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None, tp_size=1, self_attention=True)[source]

A Multi-head Attention block from Transformer models. Because the attention layer needs some customizations for Tensor Parallelism, this is a customized but mathematically equivalent version of the attention module defined in torch.nn.

Note: Only self-attention with a limited set of input arguments is currently supported, and the input tensor is assumed to be three-dimensional. Although the logic for general multihead attention is implemented, it has not been fully tested.
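The head-splitting idea behind a tensor-parallel attention layer can be checked against nn.MultiheadAttention on one process. In the sketch below (a hypothetical helper, not TensorParallelMultiheadAttention's implementation), each simulated rank takes a contiguous group of heads: it uses its slice of the Q/K/V projection rows (colwise), runs attention for its heads only, and multiplies by the matching input columns of the output projection (rowwise); summing the partial outputs stands in for the all-reduce. It assumes self-attention, a three-dimensional (L, N, E) input with batch_first=False, and num_heads divisible by tp_size.

```python
import math

import torch
import torch.nn as nn


def head_parallel_self_attention(
    mha: nn.MultiheadAttention, x: torch.Tensor, tp_size: int
) -> torch.Tensor:
    """Simulate self-attention with heads split evenly across `tp_size` ranks."""
    embed_dim, num_heads = mha.embed_dim, mha.num_heads
    head_dim = embed_dim // num_heads
    heads_per_rank = num_heads // tp_size
    seq_len, batch, _ = x.shape  # (L, N, E) layout, i.e. batch_first=False
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

    partials = []
    for rank in range(tp_size):
        # Embedding features owned by this rank's heads.
        cols = slice(rank * heads_per_rank * head_dim, (rank + 1) * heads_per_rank * head_dim)
        # Colwise: project only onto this rank's slice of Q, K and V.
        q = x @ w_q[cols].T + b_q[cols]
        k = x @ w_k[cols].T + b_k[cols]
        v = x @ w_v[cols].T + b_v[cols]
        # (L, N, local_E) -> (N, local_heads, L, head_dim)
        q, k, v = (
            t.reshape(seq_len, batch, heads_per_rank, head_dim).permute(1, 2, 0, 3)
            for t in (q, k, v)
        )
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
        ctx = (attn @ v).permute(2, 0, 1, 3).reshape(seq_len, batch, -1)
        # Rowwise: matching input columns of the output projection give a partial sum.
        partials.append(ctx @ mha.out_proj.weight[:, cols].T)

    # Summing stands in for the all-reduce; the bias is added once, afterwards.
    return torch.stack(partials).sum(dim=0) + mha.out_proj.bias


torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=4).double()
with torch.no_grad():
    mha.in_proj_bias.normal_()  # non-zero biases so the check covers them
    mha.out_proj.bias.normal_()
    x = torch.randn(5, 2, 8, dtype=torch.float64)
    expected, _ = mha(x, x, x)
    assert torch.allclose(head_parallel_self_attention(mha, x, tp_size=2), expected)
```

Adding the output-projection bias only after the reduction matters: adding it on every rank before summing would count it tp_size times.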

We also enabled 2D parallelism through integration with FullyShardedDataParallel (FSDP). Users just need to call the following API explicitly:

torch.distributed.tensor.parallel.fsdp.enable_2d_with_fsdp()[source]

The API registers the extension needed for Tensor Parallelism (TP) to work with FullyShardedDataParallel (FSDP). Parameters are first parallelized within a module or its submodules according to a parallelize_plan; FSDP then reshards the local tensor of each distributed parameter, which is essentially a DTensor.

Returns:

A bool indicating whether the extension was registered successfully.

Return type:

bool
