{ "cells": [ { "cell_type": "markdown", "id": "flush-flooring", "metadata": {}, "source": [ "## Megatron-LM Model Parallel Unit (MPU)\n", "---\n", "\n", "NVIDIA's Megatron-LM makes training very large language models (up to one trillion parameters) a reality. Megatron-LM's core, the Model Parallel Unit (MPU), is the backbone of [DeepSpeed](https://www.deepspeed.ai/features/#model-parallelism) and Facebook [FairScale](https://github.com/facebookresearch/fairscale), both of which integrate Megatron-LM's MPU in their backends.\n", "\n", "Before moving on to model training, it is crucial to understand how Megatron-LM's core engine works, especially Tensor Model Parallelism.\n", "\n", "Tensor Model Parallelism is part of the MPU. When training Megatron-LM on multiple GPUs, Tensor Model Parallelism alone is enough for a smaller model, such as a one-billion-parameter GPT-3 model. However, scaling both the model size and the compute for a multi-node training job is not trivial: Pipeline Model Parallelism has to be combined with Tensor Model Parallelism to achieve near-linear scaling during training.\n", "\n", "In this notebook, we focus solely on Tensor Model Parallelism, complemented by how Megatron-LM groups GPUs according to the training configuration.\n", "\n", "Note: You do **NOT** need 16 GPUs (i.e. 2 nodes, each with 8 x A100 GPUs). The grouping is done **virtually** in this notebook: the code is modified to print the GPU affinity grouping for the training configuration instead.\n", "\n", "## Learning Objectives\n", "\n", "The goal of this lab is to understand how Megatron-LM's MPU works. More specifically, we will cover:\n", "\n", " 1. GPU affinity grouping for the training configuration.\n", " 2. Tensor Model Parallelism:\n", "    - Column Parallel\n", "    - Row Parallel\n", "\n", "Pipeline Model Parallelism will be covered in the lecture presentation instead. 
\n" ] }, { "cell_type": "markdown", "id": "drawn-advertiser", "metadata": {}, "source": [ "Before going through GPU affinity grouping, we need to define the following parameters:\n", "\n", "- p = pipeline model-parallel size\n", "- t = tensor model-parallel size\n", "- d = data-parallel size\n", "- n = total number of GPUs used in the training\n", "\n", "Note that Megatron-LM requires p * t * d = n, so once t and p are chosen, the data-parallel size follows as d = n / (t * p).\n", "\n", "Let's first go through the default example given in Megatron-LM's [initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/initialize.py).\n", "\n", "The following training configuration is assumed:\n", "\n", "    tensor_model_parallel_size_= 2\n", "\n", "    pipeline_model_parallel_size_= 4\n", "\n", "Let's say we have a total of 16 GPUs denoted by g0 ... g15, hence\n", "\n", "    world_size = 16\n", "\n", "which gives a data-parallel size of d = 16 / (2 * 4) = 2.\n", "\n", "According to Megatron-LM's [initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/initialize.py), we should see the following GPU affinity grouping:\n", "\n", "    8 data_parallel groups:\n", "    [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]\n", "    8 tensor model-parallel groups:\n", "    [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]\n", "    4 pipeline model-parallel groups:\n", "    [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]\n", "\n", "**Note**: for efficiency, the caller should make sure adjacent ranks are on the same DGX box.\n", "For example, if we are using 2 DGX-1 boxes with a total of 16 GPUs, ranks 0 to 7 belong to the first box and\n", "ranks 8 to 15 belong to the second box."
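 ] }, { "cell_type": "markdown", "id": "rank-coordinates-sketch-md", "metadata": {}, "source": [ "The grouping above follows from simple rank arithmetic. The cell below is a minimal sketch, not part of Megatron-LM: `rank_coordinates` and `groups_sharing` are hypothetical helpers introduced here, and the rank layout (each pipeline stage owns n / p consecutive ranks, and tensor model-parallel partners are neighboring ranks) is assumed from the group construction in the linked `initialize.py`. Each rank gets a (pipeline stage, data-parallel index, tensor-parallel rank) coordinate; a tensor model-parallel group collects the ranks that differ only in their tensor-parallel rank, a data-parallel group the ranks that differ only in their data-parallel index, and a pipeline model-parallel group the ranks that differ only in their pipeline stage. With t = 2, p = 4 and n = 16, the three printed lists should match the grouping listed above." ] }, { "cell_type": "code", "execution_count": null, "id": "rank-coordinates-sketch", "metadata": {}, "outputs": [], "source": [ "# Minimal sketch with hypothetical helpers (not part of Megatron-LM), assuming\n", "# the rank layout used by initialize.py at the linked commit.\n", "def rank_coordinates(rank, t, p, n):\n", "    # Map a global rank to (pipeline stage, data-parallel index, tensor-parallel rank).\n", "    ranks_per_stage = n // p  # = t * d consecutive ranks, since p * t * d = n\n", "    pipeline_stage = rank // ranks_per_stage\n", "    data_parallel_index = (rank % ranks_per_stage) // t\n", "    tensor_rank = rank % t\n", "    return pipeline_stage, data_parallel_index, tensor_rank\n", "\n", "def groups_sharing(axis_to_vary, t, p, n):\n", "    # Ranks form one group when all coordinates except axis_to_vary match\n", "    # (0 = pipeline stage, 1 = data-parallel index, 2 = tensor-parallel rank).\n", "    groups = {}\n", "    for rank in range(n):\n", "        key = list(rank_coordinates(rank, t, p, n))\n", "        key[axis_to_vary] = None\n", "        groups.setdefault(tuple(key), []).append(rank)\n", "    return list(groups.values())\n", "\n", "t, p, n = 2, 4, 16\n", "print('tensor model-parallel groups  :', groups_sharing(2, t, p, n))\n", "print('data-parallel groups          :', groups_sharing(1, t, p, n))\n", "print('pipeline model-parallel groups:', groups_sharing(0, t, p, n))"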
] }, { "cell_type": "markdown", "id": "suburban-exclusive", "metadata": {}, "source": [ "The code block below is adapted from Megatron-LM's [initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/initialize.py) so that it does not need 16 physical GPUs; in other words, this notebook can run without any physical GPUs present." ] }, { "cell_type": "code", "execution_count": null, "id": "official-receptor", "metadata": {}, "outputs": [], "source": [ "import itertools\n", "def ensure_divisibility(numerator, denominator):\n", "    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n", "    assert numerator % denominator == 0, '{} is not divisible by {}'.format(\n", "        numerator, denominator)\n", "\n", "def initialize_model_parallel(tensor_model_parallel_size_=2,\n", "                              pipeline_model_parallel_size_=4,\n", "                              world_size=16):\n", "    print(' ---------- world size is set to : {} ---------- '.format(world_size))\n", "    print('> initializing tensor model parallel with size {}'.format(tensor_model_parallel_size_))\n", "    print('> initializing pipeline model parallel with size {}'.format(pipeline_model_parallel_size_))\n", "\n", "    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)\n", "    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)\n", "\n", "    # make sure world_size is divisible by t * p\n", "    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)\n", "\n", "    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n", "    print(\"> data parallel size is set to : \", data_parallel_size)\n", "    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size\n", "    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size\n", "    num_data_parallel_groups = world_size // data_parallel_size\n", "    print(\"---------- 
parallel groups ----------\")\n", "    print(\"num_tensor_model_parallel_groups : \", num_tensor_model_parallel_groups)\n", "    print(\"num_pipeline_model_parallel_groups : \", num_pipeline_model_parallel_groups)\n", "    print(\"num_data_parallel_groups : \", num_data_parallel_groups)\n", "    _DATA_PARALLEL_GROUP = []\n", "    _MODEL_PARALLEL_GROUP = []\n", "    _TENSOR_MODEL_PARALLEL_GROUP = []\n", "    _PIPE_MODEL_PARALLEL_GROUP = []\n", "\n", "    # Build the data-parallel groups: within each pipeline stage (a block of\n", "    # num_pipeline_model_parallel_groups = n / p consecutive ranks), the ranks\n", "    # that share the same tensor-parallel rank j form one data-parallel group.\n", "    all_data_parallel_group_ranks = []\n", "    for i in range(pipeline_model_parallel_size):\n", "        start_rank = i * num_pipeline_model_parallel_groups\n", "        end_rank = (i + 1) * num_pipeline_model_parallel_groups\n", "        for j in range(tensor_model_parallel_size):\n", "            ranks = range(start_rank + j, end_rank,\n", "                          tensor_model_parallel_size)\n", "            all_data_parallel_group_ranks.append(list(ranks))\n", "    _DATA_PARALLEL_GROUP = all_data_parallel_group_ranks\n", "\n", "    # Build the pipeline model-parallel groups: one rank from every stage.\n", "    for i in range(num_pipeline_model_parallel_groups):\n", "        ranks = range(i, world_size,\n", "                      num_pipeline_model_parallel_groups)\n", "        _PIPE_MODEL_PARALLEL_GROUP.append(list(ranks))\n", "\n", "    # Build the model-parallel groups: the i-th member of every data-parallel\n", "    # group, taken together, holds one complete copy of the model.\n", "    for i in range(data_parallel_size):\n", "        ranks = [data_parallel_group_ranks[i]\n", "                 for data_parallel_group_ranks in all_data_parallel_group_ranks]\n", "        _MODEL_PARALLEL_GROUP.append(ranks)\n", "\n", "    # Build the tensor model-parallel groups: blocks of t consecutive ranks.\n", "    for i in range(num_tensor_model_parallel_groups):\n", "        ranks = range(i * tensor_model_parallel_size,\n", "                      (i + 1) * tensor_model_parallel_size)\n", "        _TENSOR_MODEL_PARALLEL_GROUP.append(list(ranks))\n", "    print(\"-----\"*20)\n", "    print(\"_DATA_PARALLEL_GROUP \\n :\", _DATA_PARALLEL_GROUP)\n", "    print(\"-----\"*20)\n", "    print(\"_TENSOR_MODEL_PARALLEL_GROUP \\n :\", _TENSOR_MODEL_PARALLEL_GROUP)\n", "    print(\"-----\"*20)\n", "    print(\"_PIPE_MODEL_PARALLEL_GROUP \\n :\", _PIPE_MODEL_PARALLEL_GROUP)\n", "    print(\"-----\"*20)\n", "    print(\"Total :{} full models being partitioned into :{} GPUs \".format(len(_MODEL_PARALLEL_GROUP), world_size))\n", "    for idx, m in enumerate(_MODEL_PARALLEL_GROUP):\n", "        m = [str(l) for l in m]\n", "        print(\"model {} : is partitioned into gpus :{}\".format(str(idx), ','.join(m)))\n" ] }, { "cell_type": "markdown", "id": "complex-values", "metadata": {}, "source": [ "\n", "Sanity check: after running the code cell below, verify that the result matches the comment in [megatron/mpu/initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/initialize.py#L63).\n", "Since we have a total of 16 **virtual** GPUs denoted by g0 ... g15, we expect to see the following result:\n", "\n", "    8 data_parallel groups:\n", "    [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]\n", "    8 tensor model-parallel groups:\n", "    [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]\n", "    4 pipeline model-parallel groups:\n", "    [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]\n" ] }, { "cell_type": "code", "execution_count": null, "id": "numerous-somewhere", "metadata": {}, "outputs": [], "source": [ "initialize_model_parallel(tensor_model_parallel_size_=2,\n", "                          pipeline_model_parallel_size_=4,\n", "                          world_size=16)" ] }, { "cell_type": "markdown", "id": "australian-portal", "metadata": {}, "source": [ "Below is the expected output:\n", "\n", "     ---------- world size is set to : 16 ---------- \n", "    > initializing tensor model parallel with size 2\n", "    > initializing pipeline model parallel with size 4\n", "    > data parallel size is set to :  2\n", "    ---------- parallel groups ----------\n", "    num_tensor_model_parallel_groups :  8\n", "    num_pipeline_model_parallel_groups :  4\n", "    num_data_parallel_groups :  8\n", "    ----------------------------------------------------------------------------------------------------\n", "    
_DATA_PARALLEL_GROUP \n", "     : [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]\n", "    ----------------------------------------------------------------------------------------------------\n", "    _TENSOR_MODEL_PARALLEL_GROUP \n", "     : [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]\n", "    ----------------------------------------------------------------------------------------------------\n", "    _PIPE_MODEL_PARALLEL_GROUP \n", "     : [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]\n", "    ----------------------------------------------------------------------------------------------------\n", "    Total :2 full models being partitioned into :16 GPUs \n", "    model 0 : is partitioned into gpus :0,1,4,5,8,9,12,13\n", "    model 1 : is partitioned into gpus :2,3,6,7,10,11,14,15" ] }, { "cell_type": "markdown", "id": "impressive-filename", "metadata": {}, "source": [ "Try a different training configuration. What do you get?" ] }, { "cell_type": "code", "execution_count": null, "id": "restricted-favorite", "metadata": {}, "outputs": [], "source": [ "## Assuming the world size is 16, that is, you have 16 GPUs\n", "tensor_model_parallel_size = 4  # try a different tensor_model_parallel_size, e.g. 1, 2, 4 or 8\n", "pipeline_model_parallel_size = 2  # try a different pipeline_model_parallel_size, e.g. 1, 2, 4 or 8\n", "world_size = 16\n", "assert world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) == 0, 'please make sure world_size is divisible by tensor_model_parallel_size * pipeline_model_parallel_size'\n", "initialize_model_parallel(tensor_model_parallel_size_=tensor_model_parallel_size,\n", "                          pipeline_model_parallel_size_=pipeline_model_parallel_size,\n", "                          world_size=world_size)" ] }, { "cell_type": "markdown", "id": "realistic-disco", "metadata": {}, "source": [ "----------------------------------------------------------------------\n", "Megatron-LM's Column Parallel is part of Megatron-LM's Tensor Model Parallelism. 
\n", "[ColumnParallel reference](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/layers.py#L201)\n", "![ColumnParallel](./Megatron-LM/pics/ColumnParallel.JPG)" ] }, { "cell_type": "code", "execution_count": null, "id": "communist-uniform", "metadata": {}, "outputs": [], "source": [ "## Below class is modified from the original Megatron repo in order to skip environment variable initialization\n", "\n", "import sys\n", "sys.path.append(\"./Megatron-LM\")\n", "import torch\n", "import torch.nn.functional as F\n", "import torch.nn.init as init\n", "from torch.nn.parameter import Parameter\n", "from megatron.mpu.layers import _initialize_affine_weight_gpu\n", "from megatron.mpu.mappings import (copy_to_tensor_model_parallel_region,\n", "                                   gather_from_tensor_model_parallel_region,\n", "                                   reduce_from_tensor_model_parallel_region,\n", "                                   scatter_to_tensor_model_parallel_region)\n", "from megatron.mpu.utils import divide\n", "\n", "# Number of (virtual) GPUs the weight matrix is split across. For illustration,\n", "# this notebook splits across all 16 virtual GPUs; upstream Megatron-LM splits\n", "# across get_tensor_model_parallel_world_size() GPUs instead.\n", "world_size = 16\n", "\n", "class myColumnParallelLinear(torch.nn.Module):\n", "    \"\"\"Linear layer with column parallelism.\n", "    The linear layer is defined as Y = XA + b. A is parallelized along\n", "    its second dimension as A = [A_1, ..., A_p].\n", "    Arguments:\n", "        input_size: first dimension of matrix A.\n", "        output_size: second dimension of matrix A.\n", "        bias: If true, add bias\n", "        gather_output: If true, call all-gather on output and make Y available\n", "                       to all GPUs, otherwise, every GPU will have its output\n", "                       which is Y_i = XA_i\n", "        init_method: method to initialize weights. Note that bias is always set\n", "                     to zero.\n", "        stride: For the strided linear layers.\n", "        keep_master_weight_for_test: This was added for testing and should be\n", "                                     set to False. It returns the master weights\n", "                                     used for initialization.\n", "        skip_bias_add: This was added to enable performance optimizations where bias\n", "                       can be fused with other elementwise operations. 
We skip\n", "                       adding bias but instead return it.\n", "    \"\"\"\n", "\n", "    def __init__(self, input_size, output_size, bias=True, gather_output=True,\n", "                 init_method=init.xavier_normal_, stride=1,\n", "                 keep_master_weight_for_test=False,\n", "                 skip_bias_add=False):\n", "        super(myColumnParallelLinear, self).__init__()\n", "\n", "        # Keep input parameters\n", "        self.input_size = input_size\n", "        self.output_size = output_size\n", "        self.gather_output = gather_output\n", "        # Divide the weight matrix along the last dimension.\n", "        self.output_size_per_partition = divide(output_size, world_size)\n", "        self.skip_bias_add = skip_bias_add\n", "\n", "        # Parameters.\n", "        # Note: torch.nn.functional.linear performs XA^T + b and as a result\n", "        # we allocate the transpose.\n", "        # Initialize weight.\n", "        use_cpu_initialization = True  # hard coded to use cpu\n", "        params_dtype = torch.float  # skipping need of args\n", "\n", "        if use_cpu_initialization:\n", "            self.weight = Parameter(torch.empty(self.output_size_per_partition,\n", "                                                self.input_size,\n", "                                                dtype=params_dtype))\n", "            self.master_weight = m_initialize_affine_weight_cpu(\n", "                self.weight, self.output_size, self.input_size,\n", "                self.output_size_per_partition, 0, init_method,\n", "                stride=stride, return_master_weight=keep_master_weight_for_test)\n", "        else:\n", "            self.weight = Parameter(torch.empty(\n", "                self.output_size_per_partition, self.input_size,\n", "                device=torch.cuda.current_device(), dtype=params_dtype))\n", "            _initialize_affine_weight_gpu(self.weight, init_method,\n", "                                          partition_dim=0, stride=stride)\n", "\n", "        if bias:\n", "            if use_cpu_initialization:\n", "                self.bias = Parameter(torch.empty(\n", "                    self.output_size_per_partition, dtype=params_dtype))\n", "            else:\n", "                self.bias = Parameter(torch.empty(\n", "                    self.output_size_per_partition,\n", "                    device=torch.cuda.current_device(),\n", "                    dtype=params_dtype))\n", "            # Always initialize bias to zero.\n", "            with 
torch.no_grad():\n", "                self.bias.zero_()\n", "        else:\n", "            self.register_parameter('bias', None)\n", "\n", "    def forward(self, input_):\n", "        # Set up backprop all-reduce.\n", "        print(\"in Column parallel forward\")\n", "        input_parallel = copy_to_tensor_model_parallel_region(input_)\n", "        # Matrix multiply.\n", "        bias = self.bias if not self.skip_bias_add else None\n", "        output_parallel = F.linear(input_parallel, self.weight, bias)\n", "        if self.gather_output:\n", "            # All-gather across the partitions.\n", "            output = gather_from_tensor_model_parallel_region(output_parallel)\n", "        else:\n", "            output = output_parallel\n", "        output_bias = self.bias if self.skip_bias_add else None\n", "        return output, output_bias" ] }, { "cell_type": "code", "execution_count": null, "id": "tired-moment", "metadata": {}, "outputs": [], "source": [ "def get_weight_list(master_weight, tensor_model_parallel_gp):\n", "    \"\"\"Slice master_weight (stored as A^T with shape [output_size, input_size])\n", "    into world_size partitions A_i of A, print each shape and return the list.\n", "    Reads the notebook globals which_model_parallel ('col' or 'row') and world_size.\"\"\"\n", "    my_weight_list = []\n", "    print(\"A = [\")\n", "    # Flatten the groups into the list of (virtual) GPU ranks 0, 1, ..., 15.\n", "    tensor_model_parallel_gp = list(itertools.chain(*tensor_model_parallel_gp))\n", "    for cnt, gp in enumerate(tensor_model_parallel_gp):\n", "        if which_model_parallel == 'col':\n", "            # Columns gp, gp + world_size, ... of A (an interleaved split; upstream\n", "            # Megatron-LM takes contiguous blocks when stride=1).\n", "            temp = master_weight[gp::world_size].T\n", "            if cnt < world_size - 1:\n", "                print(\"A{}=\".format(str(cnt)), temp.size(), end=',')\n", "            else:\n", "                print(\"A{}=\".format(str(cnt)), temp.size())\n", "        elif which_model_parallel == 'row':\n", "            # Rows gp, gp + world_size, ... of A.\n", "            temp = master_weight.T[gp::world_size]\n", "            if cnt < world_size - 1:\n", "                print(\"A{}=\".format(str(cnt)), temp.size(), ',')\n", "            else:\n", "                print(\"A{}=\".format(str(cnt)), temp.size())\n", "        else:\n", "            raise ValueError(\"set which_model_parallel to 'col' or 'row'\")\n", "        my_weight_list.append(temp)\n", "    print(\" ]\")\n", "    print(len(my_weight_list))\n", "    return my_weight_list\n", "\n", "def m_initialize_affine_weight_cpu(weight, output_size, input_size,\n", "                                   per_partition_size, partition_dim,\n", "                                   init_method, stride=1,\n", "                                   
return_master_weight=False):\n", "    \"\"\"Initialize affine weight for model parallel.\n", "    Build the master weight on all processes and scatter\n", "    the relevant chunk.\"\"\"\n", "    params_dtype = torch.float\n", "    # Initialize master weight\n", "    master_weight = torch.empty(output_size, input_size,\n", "                                dtype=torch.float,\n", "                                requires_grad=False)\n", "    init_method(master_weight)\n", "    master_weight = master_weight.to(dtype=params_dtype)\n", "    # Split and copy\n", "    per_partition_per_stride_size = divide(per_partition_size, stride)\n", "    print(\"per_partition_per_stride_size \", per_partition_per_stride_size)\n", "    ######## tensor_model_parallel_gp below is hard-coded for tensor_model_parallel_size=2, pipeline_model_parallel_size=4 ########\n", "    ######## if you use another model-parallel configuration, replace it with that configuration's _TENSOR_MODEL_PARALLEL_GROUP ########\n", "    tensor_model_parallel_gp = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]\n", "    my_weight_list = get_weight_list(master_weight, tensor_model_parallel_gp)\n", "\n", "    with torch.no_grad():\n", "        # This layer keeps only the shard of virtual GPU 0 (on real hardware each\n", "        # GPU would keep its own shard); transpose A_0 back to the stored layout.\n", "        weight.copy_(my_weight_list[0].T)\n", "    if return_master_weight:\n", "        return master_weight\n", "    return None" ] }, { "cell_type": "markdown", "id": "pediatric-february", "metadata": {}, "source": [ "Looking inside the Column Parallel class, we can see that it partitions the weight matrix A into [A0, A1, A2, ..., An], where each partition Ai (i = 0, 1, 2, ..., n) is a column-wise slice of A. Each partition Ai would then be sent to the corresponding GPU according to the GPU affinity grouping, where that GPU computes Yi = X Ai; when gather_output is true, the Yi are concatenated with an all-gather. However, since we do not physically have 16 GPUs and use the modified code above instead, we do not allocate any physical GPUs. 
Instead, we only print the shape of each column-wise partition Ai and do **NOT** send any Ai to a GPU." ] }, { "cell_type": "code", "execution_count": null, "id": "competitive-honor", "metadata": {}, "outputs": [], "source": [ "tensor_model_parallel_size = 2\n", "pipeline_model_parallel_size = 4\n", "input_size = 1024  # first dimension of A: 1024 rows\n", "output_size = 512  # second dimension of A: 512 columns\n", "which_model_parallel = 'col'\n", "print(\"this is how A is sliced column-wise ...\\n\")\n", "testCol = myColumnParallelLinear(input_size, output_size, bias=True, gather_output=True,\n", "                                 init_method=init.xavier_normal_, stride=1,\n", "                                 keep_master_weight_for_test=False,\n", "                                 skip_bias_add=False)" ] }, { "cell_type": "markdown", "id": "manual-roman", "metadata": {}, "source": [ "Below is the expected output:\n", "\n", "    this is how A is sliced column-wise ...\n", "\n", "    per_partition_per_stride_size  32\n", "    A = [\n", "    A0= torch.Size([1024, 32]),A1= torch.Size([1024, 32]),A2= torch.Size([1024, 32]),A3= torch.Size([1024, 32]),A4= torch.Size([1024, 32]),A5= torch.Size([1024, 32]),A6= torch.Size([1024, 32]),A7= torch.Size([1024, 32]),A8= torch.Size([1024, 32]),A9= torch.Size([1024, 32]),A10= torch.Size([1024, 32]),A11= torch.Size([1024, 32]),A12= torch.Size([1024, 32]),A13= torch.Size([1024, 32]),A14= torch.Size([1024, 32]),A15= torch.Size([1024, 32])\n", "     ]\n", "    16" ] }, { "cell_type": "code", "execution_count": null, "id": "soviet-turner", "metadata": {}, "outputs": [], "source": [ "# 16 partitions of 32 columns each make up the 512 columns of A\n", "per_partition_per_stride_size = 32\n", "assert 16 * per_partition_per_stride_size == 512" ] }, { "cell_type": "code", "execution_count": null, "id": "alternate-glenn", "metadata": {}, "outputs": [], "source": [ "type(testCol)" ] }, { "cell_type": "markdown", "id": "impossible-orbit", "metadata": {}, "source": [ "Below is the expected output:\n", "\n", "    __main__.myColumnParallelLinear" ] }, { "cell_type": "code", "execution_count": null, "id": "flexible-cooper", "metadata": 
{}, "outputs": [], "source": [ "testCol.input_size, testCol.output_size" ] }, { "cell_type": "markdown", "id": "broadband-producer", "metadata": {}, "source": [ "Below is the expected output:\n", "\n", "    (1024, 512)" ] }, { "cell_type": "markdown", "id": "united-watch", "metadata": {}, "source": [ "----------------------------------------------------------------------\n", "Megatron-LM's Row Parallel is part of Megatron-LM's Tensor Model Parallelism.\n", "[RowParallel reference](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/layers.py#L294)\n", "![RowParallel](./Megatron-LM/pics/RowParallel.JPG)" ] }, { "cell_type": "code", "execution_count": null, "id": "silver-roommate", "metadata": {}, "outputs": [], "source": [ "## Below class is modified from the original Megatron repo in order to skip environment variable initialization\n", "\n", "class myRowParallelLinear(torch.nn.Module):\n", "    \"\"\"Linear layer with row parallelism.\n", "    The linear layer is defined as Y = XA + b. A is parallelized along\n", "    its first dimension and X along its second dimension as:\n", "               -   -\n", "              | A_1 |\n", "              | .   |\n", "          A = | .   |        X = [X_1, ..., X_p]\n", "              | .   |\n", "              | A_p |\n", "               -   -\n", "    Arguments:\n", "        input_size: first dimension of matrix A.\n", "        output_size: second dimension of matrix A.\n", "        bias: If true, add bias. Note that bias is not parallelized.\n", "        input_is_parallel: If true, we assume that the input is already\n", "                           split across the GPUs and we do not split\n", "                           again.\n", "        init_method: method to initialize weights. Note that bias is always set\n", "                     to zero.\n", "        stride: For the strided linear layers.\n", "        keep_master_weight_for_test: This was added for testing and should be\n", "                                     set to False. It returns the master weights\n", "                                     used for initialization.\n", "        skip_bias_add: This was added to enable performance optimizations where bias\n", "                       can be fused with other elementwise operations. 
We skip\n", "                       adding bias but instead return it.\n", "    \"\"\"\n", "\n", "    def __init__(self, input_size, output_size, bias=True,\n", "                 input_is_parallel=False,\n", "                 init_method=init.xavier_normal_, stride=1,\n", "                 keep_master_weight_for_test=False,\n", "                 skip_bias_add=False):\n", "        super(myRowParallelLinear, self).__init__()\n", "\n", "        # Keep input parameters\n", "        self.input_size = input_size\n", "        self.output_size = output_size\n", "        self.input_is_parallel = input_is_parallel\n", "        # Divide the weight matrix along the last dimension.\n", "        self.input_size_per_partition = divide(input_size, world_size)\n", "        self.skip_bias_add = skip_bias_add\n", "        print(\"input_size_per_partition \", self.input_size_per_partition)\n", "\n", "        # Parameters.\n", "        # Note: torch.nn.functional.linear performs XA^T + b and as a result\n", "        # we allocate the transpose.\n", "        # Initialize weight.\n", "        use_cpu_initialization = True  # hard coded to use cpu\n", "        params_dtype = torch.float  # skipping need of args\n", "        if use_cpu_initialization:\n", "            self.weight = Parameter(torch.empty(self.output_size,\n", "                                                self.input_size_per_partition,\n", "                                                dtype=params_dtype))\n", "            self.master_weight = m_initialize_affine_weight_cpu(\n", "                self.weight, self.output_size, self.input_size,\n", "                self.input_size_per_partition, 1, init_method,\n", "                stride=stride, return_master_weight=keep_master_weight_for_test)\n", "        else:\n", "            self.weight = Parameter(torch.empty(\n", "                self.output_size, self.input_size_per_partition,\n", "                device=torch.cuda.current_device(), dtype=params_dtype))\n", "            _initialize_affine_weight_gpu(self.weight, init_method,\n", "                                          partition_dim=1, stride=stride)\n", "        if bias:\n", "            if use_cpu_initialization:\n", "                self.bias = Parameter(torch.empty(self.output_size,\n", "                                                  dtype=params_dtype))\n", "            else:\n", "                self.bias = Parameter(torch.empty(\n", "                    self.output_size, device=torch.cuda.current_device(),\n", "                    dtype=params_dtype))\n", "            # Always initialize bias to zero.\n", "            
with torch.no_grad():\n", "                self.bias.zero_()\n", "        else:\n", "            self.register_parameter('bias', None)\n", "\n", "    def forward(self, input_):\n", "        # Set up backprop all-reduce.\n", "        if self.input_is_parallel:\n", "            input_parallel = input_\n", "        else:\n", "            input_parallel = scatter_to_tensor_model_parallel_region(input_)\n", "        # Matrix multiply.\n", "        output_parallel = F.linear(input_parallel, self.weight)\n", "        # All-reduce across all the partitions.\n", "        output_ = reduce_from_tensor_model_parallel_region(output_parallel)\n", "        if not self.skip_bias_add:\n", "            output = output_ + self.bias if self.bias is not None else output_\n", "            output_bias = None\n", "        else:\n", "            output = output_\n", "            output_bias = self.bias\n", "        return output, output_bias" ] }, { "cell_type": "markdown", "id": "explicit-filing", "metadata": {}, "source": [ "Looking inside the Row Parallel class, we can see that it partitions the weight matrix A into [A0, A1, A2, ..., An], where each partition Ai (i = 0, 1, 2, ..., n) is a row-wise slice of A. The input X is split along its second dimension to match (X = [X0, X1, ..., Xn]), so the GPU holding Ai computes the partial product Xi Ai, and the partial products are summed with an all-reduce. Each partition Ai would then be sent to the corresponding GPU according to the GPU affinity grouping. However, since we do not physically have 16 GPUs and use the modified code above instead, we do not allocate any physical GPUs. Instead, we only print the shape of each row-wise partition Ai and do **NOT** send any Ai to a GPU."
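 ] }, { "cell_type": "markdown", "id": "tensor-parallel-algebra-md", "metadata": {}, "source": [ "The algebra behind both partitioning schemes can be checked numerically without any GPUs. The cell below is a CPU-only sketch that simulates the partitions inside one process with plain `torch` tensors instead of `torch.distributed`; the number of partitions `t = 4`, the batch of 8 inputs and the second weight matrix `B` are illustrative assumptions. Column Parallel concatenates the partial outputs X A_i (the all-gather), Row Parallel sums the partial products X_i A_i (the all-reduce), and both should reproduce Y = XA up to floating-point rounding. The last part chains a column-parallel layer into a row-parallel layer with a GeLU in between, which is how Megatron-LM's MLP block combines the two layer types (with `gather_output=False` and `input_is_parallel=True`): because GeLU is elementwise, each partition can apply it locally, and only the final sum needs communication." ] }, { "cell_type": "code", "execution_count": null, "id": "tensor-parallel-algebra", "metadata": {}, "outputs": [], "source": [ "# CPU-only sketch: simulate t tensor-parallel partitions in one process.\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "torch.manual_seed(0)\n", "t = 4                                     # illustrative number of partitions\n", "X = torch.randn(8, 1024)                  # 8 inputs, input_size = 1024\n", "A = torch.randn(1024, 512) / 1024 ** 0.5  # Y = XA with A of shape [1024, 512]\n", "Y = X @ A\n", "\n", "# Column Parallel: split A along its second dimension, A = [A_1, ..., A_t].\n", "# Partition i computes X A_i; the all-gather concatenates the results.\n", "A_cols = torch.chunk(A, t, dim=1)\n", "Y_col = torch.cat([X @ A_i for A_i in A_cols], dim=1)\n", "\n", "# Row Parallel: split A along its first dimension and X along its second.\n", "# Partition i computes X_i A_i; the all-reduce sums the partial products.\n", "A_rows = torch.chunk(A, t, dim=0)\n", "X_cols = torch.chunk(X, t, dim=1)\n", "Y_row = sum(X_i @ A_i for X_i, A_i in zip(X_cols, A_rows))\n", "\n", "# Column Parallel feeding Row Parallel with GeLU in between: GeLU is applied\n", "# to each X A_i locally, and only the final sum needs an all-reduce.\n", "B = torch.randn(512, 256) / 512 ** 0.5\n", "Z = F.gelu(X @ A) @ B\n", "B_rows = torch.chunk(B, t, dim=0)\n", "Z_par = sum(F.gelu(X @ A_i) @ B_i for A_i, B_i in zip(A_cols, B_rows))\n", "\n", "print('Column Parallel matches XA:', torch.allclose(Y_col, Y, atol=1e-4))\n", "print('Row Parallel matches XA   :', torch.allclose(Y_row, Y, atol=1e-4))\n", "print('Column -> Row matches     :', torch.allclose(Z_par, Z, atol=1e-4))"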
] }, { "cell_type": "code", "execution_count": null, "id": "whole-generic", "metadata": {}, "outputs": [], "source": [ "input_size = 1024  # first dimension of matrix A\n", "output_size = 512  # second dimension of matrix A\n", "print(\"this is how A is sliced row-wise ...\\n\")\n", "which_model_parallel = 'row'\n", "testRow = myRowParallelLinear(input_size, output_size, bias=True,\n", "                              input_is_parallel=False,\n", "                              init_method=init.xavier_normal_, stride=1,\n", "                              keep_master_weight_for_test=False,\n", "                              skip_bias_add=False)" ] }, { "cell_type": "markdown", "id": "invisible-chocolate", "metadata": {}, "source": [ "Below is the expected output:\n", "\n", "    this is how A is sliced row-wise ...\n", "\n", "    input_size_per_partition  64\n", "    per_partition_per_stride_size  64\n", "    A = [\n", "    A0= torch.Size([64, 512]) ,\n", "    A1= torch.Size([64, 512]) ,\n", "    A2= torch.Size([64, 512]) ,\n", "    A3= torch.Size([64, 512]) ,\n", "    A4= torch.Size([64, 512]) ,\n", "    A5= torch.Size([64, 512]) ,\n", "    A6= torch.Size([64, 512]) ,\n", "    A7= torch.Size([64, 512]) ,\n", "    A8= torch.Size([64, 512]) ,\n", "    A9= torch.Size([64, 512]) ,\n", "    A10= torch.Size([64, 512]) ,\n", "    A11= torch.Size([64, 512]) ,\n", "    A12= torch.Size([64, 512]) ,\n", "    A13= torch.Size([64, 512]) ,\n", "    A14= torch.Size([64, 512]) ,\n", "    A15= torch.Size([64, 512])\n", "     ]\n", "    16" ] }, { "cell_type": "code", "execution_count": null, "id": "challenging-chassis", "metadata": {}, "outputs": [], "source": [ "# 16 partitions of 64 rows each make up the 1024 rows of A\n", "per_partition_per_stride_size = 64\n", "assert 16 * per_partition_per_stride_size == 1024" ] }, { "cell_type": "code", "execution_count": null, "id": "national-physics", "metadata": {}, "outputs": [], "source": [ "testRow.input_size, testRow.output_size" ] }, { "cell_type": "markdown", "id": "olympic-probe", "metadata": {}, "source": [ "Below is the expected output:\n", "\n", "    (1024, 512)" ] }, { "cell_type": "markdown", "id": "operating-lender", "metadata": {}, "source": [ "---\n", "\n", "## Links and 
Resources\n", "Don't forget to check out additional resources such as [Efficient Large-Scale Language Model Training on GPU Clusters](https://arxiv.org/pdf/2104.04473.pdf ) and [Pushing Forward the Frontiers of Natural Language Processing](https://blogs.nvidia.com/blog/2021/09/16/nlp-frontiers-ai-hardware-summit/)." ] }, { "cell_type": "markdown", "id": "literary-orchestra", "metadata": {}, "source": [ "-----
\n" ] }, { "cell_type": "markdown", "id": "passing-plaza", "metadata": {}, "source": [ "-----\n", "\n", "\n", "## Licensing \n", "\n", "This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }