{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "wooden-perth",
   "metadata": {},
   "source": [
    "## Megatron-LM Model Parallel Unit (MPU)\n",
    "---\n",
    "\n",
    "NVIDIA's Megatron-LM makes training very large language models (up to one trillion parameters) a reality. Megatron-LM's core, the Model Parallel Unit (MPU), is the backbone for [DeepSpeed](https://www.deepspeed.ai/features/#model-parallelism) and Facebook [FairScale](https://github.com/facebookresearch/fairscale); both of them integrate Megatron-LM's MPU heavily in their backends.\n",
    "\n",
    "## Learning Objectives\n",
    "\n",
    "The goal of this lab is to understand how Megatron-LM's Model Parallel Unit (MPU) works. More specifically, we will cover:\n",
    "\n",
    "- GPU grouping affinity for a given training configuration.\n",
    "- Tensor Parallelism:\n",
    "    - Column Parallel\n",
    "    - Row Parallel\n",
    "\n",
    "Pipeline parallelism will be covered in the lecture instead.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "legislative-dividend",
   "metadata": {},
   "source": [
    "Before going through GPU grouping affinity, we need to define the following parameters:\n",
    "\n",
    "- p = Pipeline Model Parallel size\n",
    "- t = Tensor Model Parallel size\n",
    "- d = Data Parallel size\n",
    "- n = Total number of GPUs used in the training\n",
    "\n",
    "Note that Megatron-LM requires p * t * d = n.\n",
    "\n",
    "\n",
    "Let's first go through the default example given in Megatron-LM's [initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/initialize.py).\n",
    "\n",
    "The following training configuration is assumed:\n",
    "\n",
    "    tensor_model_parallel_size_ = 2\n",
    "\n",
    "    pipeline_model_parallel_size_ = 4\n",
    "\n",
    "Let's say we have a total of 16 GPUs, denoted by g0 ... g15, hence\n",
    "\n",
    "    world_size = 16\n",
    "\n",
    "so the data-parallel size is d = n / (p * t) = 16 / (4 * 2) = 2.\n",
    "\n",
    "According to Megatron-LM's [initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/initialize.py), we should see the following:\n",
    "\n",
    "    8 data_parallel groups:\n",
    "        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]\n",
    "    8 tensor model-parallel groups:\n",
    "        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]\n",
    "    4 pipeline model-parallel groups:\n",
    "        [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]\n",
    "\n",
    "**Note** that for efficiency, the caller should make sure adjacent ranks are on the same DGX box.\n",
    "For example, if we are using 2 DGX-1 boxes\n",
    "with a total of 16 GPUs, ranks 0 to 7 belong to the first box and\n",
    "ranks 8 to 15 belong to the second box."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faced-detroit",
   "metadata": {},
   "source": [
    "The code block below is modified from Megatron-LM's [initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/initialize.py) so that it does not require 16 physical GPUs; in other words, this notebook can be run without any physical GPUs present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cooperative-contractor",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "\n",
    "def ensure_divisibility(numerator, denominator):\n",
    "    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n",
    "    assert numerator % denominator == 0, '{} is not divisible by {}'.format(\n",
    "        numerator, denominator)\n",
    "\n",
    "\n",
    "def initialize_model_parallel(tensor_model_parallel_size_=2,\n",
    "                              pipeline_model_parallel_size_=4,\n",
    "                              world_size=16):\n",
    "    \"\"\"Simulate Megatron-LM's group construction using plain lists of ranks.\"\"\"\n",
    "    print(' ---------- world size is set to : {} ---------- '.format(world_size))\n",
    "    print('> initializing tensor model parallel with size {}'.format(tensor_model_parallel_size_))\n",
    "    print('> initializing pipeline model parallel with size {}'.format(pipeline_model_parallel_size_))\n",
    "\n",
    "    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)\n",
    "    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)\n",
    "\n",
    "    # Make sure world_size is divisible by t * p.\n",
    "    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)\n",
    "\n",
    "    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n",
    "    print(\"> data parallel size is set to : \", data_parallel_size)\n",
    "    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size\n",
    "    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size\n",
    "    num_data_parallel_groups = world_size // data_parallel_size\n",
    "    print(\"---------- parallel groups ----------\")\n",
    "    print(\"num_tensor_model_parallel_groups : \", num_tensor_model_parallel_groups)\n",
    "    print(\"num_pipeline_model_parallel_groups : \", num_pipeline_model_parallel_groups)\n",
    "    print(\"num_data_parallel_groups : \", num_data_parallel_groups)\n",
    "\n",
    "    # Each group is stored as a list of ranks instead of a torch.distributed group.\n",
    "    _DATA_PARALLEL_GROUP = []\n",
    "    _MODEL_PARALLEL_GROUP = []\n",
    "    _TENSOR_MODEL_PARALLEL_GROUP = []\n",
    "    _PIPE_MODEL_PARALLEL_GROUP = []\n",
    "\n",
    "    # Build the data-parallel groups: inside each block of consecutive ranks,\n",
    "    # ranks holding the same tensor shard (stride t) form one data-parallel group.\n",
    "    all_data_parallel_group_ranks = []\n",
    "    for i in range(pipeline_model_parallel_size):\n",
    "        start_rank = i * num_pipeline_model_parallel_groups\n",
    "        end_rank = (i + 1) * num_pipeline_model_parallel_groups\n",
    "        for j in range(tensor_model_parallel_size):\n",
    "            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n",
    "            all_data_parallel_group_ranks.append(list(ranks))\n",
    "    _DATA_PARALLEL_GROUP = all_data_parallel_group_ranks\n",
    "\n",
    "    # Build the pipeline model-parallel groups: ranks strided by the number of pipeline groups.\n",
    "    for i in range(num_pipeline_model_parallel_groups):\n",
    "        ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n",
    "        _PIPE_MODEL_PARALLEL_GROUP.append(list(ranks))\n",
    "\n",
    "    # Build the model-parallel groups: the i-th member of every data-parallel\n",
    "    # group together holds one full copy of the model.\n",
    "    for i in range(data_parallel_size):\n",
    "        ranks = [data_parallel_group_ranks[i]\n",
    "                 for data_parallel_group_ranks in all_data_parallel_group_ranks]\n",
    "        _MODEL_PARALLEL_GROUP.append(ranks)\n",
    "\n",
    "    # Build the tensor model-parallel groups: t consecutive ranks each.\n",
    "    for i in range(num_tensor_model_parallel_groups):\n",
    "        ranks = range(i * tensor_model_parallel_size,\n",
    "                      (i + 1) * tensor_model_parallel_size)\n",
    "        _TENSOR_MODEL_PARALLEL_GROUP.append(list(ranks))\n",
    "\n",
    "    print(\"-----\" * 20)\n",
    "    print(\"_DATA_PARALLEL_GROUP \\n :\", _DATA_PARALLEL_GROUP)\n",
    "    print(\"-----\" * 20)\n",
    "    print(\"_TENSOR_MODEL_PARALLEL_GROUP \\n :\", _TENSOR_MODEL_PARALLEL_GROUP)\n",
    "    print(\"-----\" * 20)\n",
    "    print(\"_PIPE_MODEL_PARALLEL_GROUP \\n :\", _PIPE_MODEL_PARALLEL_GROUP)\n",
    "    print(\"-----\" * 20)\n",
    "    print(\"Total :{} full models being partitioned into :{} GPUs \".format(len(_MODEL_PARALLEL_GROUP), world_size))\n",
    "    for idx, m in enumerate(_MODEL_PARALLEL_GROUP):\n",
    "        m = [str(l) for l in m]\n",
    "        print(\"model {} : is partitioned into gpus :{}\".format(idx, ','.join(m)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "floppy-princeton",
   "metadata": {},
   "source": [
    "\n",
    "Sanity check: after running the code cell below, verify that the result matches the comment in [megatron/mpu/initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/initialize.py#L63).\n",
    "Since we have a total of 16 GPUs, denoted by g0 ... g15, we expect to see the following result:\n",
    "\n",
    "    8 data_parallel groups:\n",
    "        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]\n",
    "    8 tensor model-parallel groups:\n",
    "        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]\n",
    "    4 pipeline model-parallel groups:\n",
    "        [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "postal-hotel",
   "metadata": {},
   "outputs": [],
   "source": [
    "initialize_model_parallel(tensor_model_parallel_size_=2,\n",
    "                          pipeline_model_parallel_size_=4,\n",
    "                          world_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "electric-differential",
   "metadata": {},
   "source": [
    "Below is the expected output:\n",
    "\n",
    "     ---------- world size is set to : 16 ---------- \n",
    "    > initializing tensor model parallel with size 2\n",
    "    > initializing pipeline model parallel with size 4\n",
    "    > data parallel size is set to :  2\n",
    "    ---------- parallel groups ----------\n",
    "    num_tensor_model_parallel_groups :  8\n",
    "    num_pipeline_model_parallel_groups :  4\n",
    "    num_data_parallel_groups :  8\n",
    "    ----------------------------------------------------------------------------------------------------\n",
    "    _DATA_PARALLEL_GROUP \n",
    "     : [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]\n",
    "    ----------------------------------------------------------------------------------------------------\n",
    "    _TENSOR_MODEL_PARALLEL_GROUP \n",
    "     : [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]\n",
    "    ----------------------------------------------------------------------------------------------------\n",
    "    _PIPE_MODEL_PARALLEL_GROUP \n",
    "     : [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]\n",
    "    ----------------------------------------------------------------------------------------------------\n",
    "    Total :2 full models being partitioned into :16 GPUs \n",
    "    model 0 : is partitioned into gpus :0,1,4,5,8,9,12,13\n",
    "    model 1 : is partitioned into gpus :2,3,6,7,10,11,14,15"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "prime-puppy",
   "metadata": {},
   "source": [
    "Try a different training configuration. What do you get?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "external-freedom",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Assuming the world size is 16, that is, you have 16 GPUs.\n",
    "tensor_model_parallel_size = 4  # example value: try a different tensor_model_parallel_size_\n",
    "pipeline_model_parallel_size = 2  # example value: try a different pipeline_model_parallel_size_\n",
    "world_size = 16\n",
    "assert world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) == 0, 'please make sure world_size is divisible by tensor_model_parallel_size * pipeline_model_parallel_size'\n",
    "initialize_model_parallel(tensor_model_parallel_size_=tensor_model_parallel_size,\n",
    "                          pipeline_model_parallel_size_=pipeline_model_parallel_size,\n",
    "                          world_size=world_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "every-cotton",
   "metadata": {},
   "source": [
    "----------------------------------------------------------------------\n",
    "## Megatron-LM's Column Parallel\n",
    "Column Parallel is part of Megatron-LM's Tensor Parallelism.\n",
    "[ColumnParallel reference](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/layers.py#L201)\n",
    "![ColumnParallel](./Megatron-LM/pics/ColumnParallel.JPG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "turkish-capability",
   "metadata": {},
   "outputs": [],
   "source": [
    "## The class below is modified from the original Megatron-LM repo to skip environment-variable initialization.\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"./Megatron-LM\")\n",
    "from megatron.mpu import layers\n",
    "from torch.nn.parameter import Parameter\n",
    "import torch.nn.init as init\n",
    "import torch\n",
    "import torch.nn.functional as F  # used by F.linear in forward()\n",
    "import random\n",
    "from megatron import *\n",
    "from megatron.mpu.tests import *\n",
    "from megatron.mpu.utils import *\n",
    "# Communication helpers used in forward(); forward() is not called in this notebook.\n",
    "from megatron.mpu.mappings import (copy_to_tensor_model_parallel_region,\n",
    "                                   gather_from_tensor_model_parallel_region,\n",
    "                                   reduce_from_tensor_model_parallel_region,\n",
    "                                   scatter_to_tensor_model_parallel_region)\n",
    "\n",
    "global world_size\n",
    "world_size = 16\n",
    "\n",
    "\n",
    "class myColumnParallelLinear(torch.nn.Module):\n",
    "    \"\"\"Linear layer with column parallelism.\n",
    "\n",
    "    The linear layer is defined as Y = XA + b. A is parallelized along\n",
    "    its second dimension as A = [A_1, ..., A_p].\n",
    "\n",
    "    Arguments:\n",
    "        input_size: first dimension of matrix A.\n",
    "        output_size: second dimension of matrix A.\n",
    "        bias: If true, add bias.\n",
    "        gather_output: If true, call all-gather on output and make Y available\n",
    "                       to all GPUs; otherwise, every GPU will have its output\n",
    "                       which is Y_i = XA_i.\n",
    "        init_method: method to initialize weights. Note that bias is always set\n",
    "                     to zero.\n",
    "        stride: For the strided linear layers.\n",
    "        keep_master_weight_for_test: This was added for testing and should be\n",
    "                                     set to False. It returns the master weights\n",
    "                                     used for initialization.\n",
    "        skip_bias_add: This was added to enable performance optimizations where bias\n",
    "                       can be fused with other elementwise operations. We skip\n",
    "                       adding bias but instead return it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, output_size, bias=True, gather_output=True,\n",
    "                 init_method=init.xavier_normal_, stride=1,\n",
    "                 keep_master_weight_for_test=False,\n",
    "                 skip_bias_add=False):\n",
    "        super(myColumnParallelLinear, self).__init__()\n",
    "\n",
    "        # Keep input parameters\n",
    "        self.input_size = input_size\n",
    "        self.output_size = output_size\n",
    "        self.gather_output = gather_output\n",
    "        # Divide the weight matrix along the last dimension.\n",
    "        # Simplification: partition across all world_size GPUs, whereas Megatron-LM\n",
    "        # divides by the tensor-model-parallel world size.\n",
    "        self.output_size_per_partition = divide(output_size, world_size)\n",
    "        self.skip_bias_add = skip_bias_add\n",
    "\n",
    "        # Parameters.\n",
    "        # Note: torch.nn.functional.linear performs XA^T + b and as a result\n",
    "        # we allocate the transpose.\n",
    "        # Initialize weight.\n",
    "        use_cpu_initialization = True  # hard-coded to use the CPU\n",
    "        params_dtype = torch.float  # skips the need for Megatron args\n",
    "\n",
    "        if use_cpu_initialization:\n",
    "            self.weight = Parameter(torch.empty(self.output_size_per_partition,\n",
    "                                                self.input_size,\n",
    "                                                dtype=params_dtype))\n",
    "            self.master_weight = m_initialize_affine_weight_cpu(\n",
    "                self.weight, self.output_size, self.input_size,\n",
    "                self.output_size_per_partition, 0, init_method,\n",
    "                stride=stride, return_master_weight=keep_master_weight_for_test)\n",
    "        else:\n",
    "            self.weight = Parameter(torch.empty(\n",
    "                self.output_size_per_partition, self.input_size,\n",
    "                device=torch.cuda.current_device(), dtype=params_dtype))\n",
    "            layers._initialize_affine_weight_gpu(self.weight, init_method,\n",
    "                                                 partition_dim=0, stride=stride)\n",
    "\n",
    "        if bias:\n",
    "            if use_cpu_initialization:\n",
    "                self.bias = Parameter(torch.empty(\n",
    "                    self.output_size_per_partition, dtype=params_dtype))\n",
    "            else:\n",
    "                self.bias = Parameter(torch.empty(\n",
    "                    self.output_size_per_partition,\n",
    "                    device=torch.cuda.current_device(),\n",
    "                    dtype=params_dtype))\n",
    "            # Always initialize bias to zero.\n",
    "            with torch.no_grad():\n",
    "                self.bias.zero_()\n",
    "        else:\n",
    "            self.register_parameter('bias', None)\n",
    "\n",
    "    def forward(self, input_):\n",
    "        # Set up backprop all-reduce.\n",
    "        print(\"in Column parallel forward\")\n",
    "        input_parallel = copy_to_tensor_model_parallel_region(input_)\n",
    "        # Matrix multiply.\n",
    "        bias = self.bias if not self.skip_bias_add else None\n",
    "        output_parallel = F.linear(input_parallel, self.weight, bias)\n",
    "        if self.gather_output:\n",
    "            # All-gather across the partitions.\n",
    "            output = gather_from_tensor_model_parallel_region(output_parallel)\n",
    "        else:\n",
    "            output = output_parallel\n",
    "        output_bias = self.bias if self.skip_bias_add else None\n",
    "        return output, output_bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "experimental-ocean",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_weight_list(master_weight, tensor_model_parallel_gp):\n",
    "    \"\"\"Slice master_weight into one partition per rank.\n",
    "\n",
    "    Relies on the notebook globals which_model_parallel ('col' or 'row') and world_size.\"\"\"\n",
    "    my_weight_list = []\n",
    "    print(\"A = [\")\n",
    "    # Flatten the list of groups into a single list of ranks.\n",
    "    tensor_model_parallel_gp = list(itertools.chain(*tensor_model_parallel_gp))\n",
    "    cnt = 0\n",
    "    for gp in tensor_model_parallel_gp:\n",
    "        if which_model_parallel == 'col':\n",
    "            # master_weight stores A^T, so strided rows of it are strided columns of A.\n",
    "            temp = master_weight[gp::world_size].T\n",
    "            if cnt < world_size - 1:\n",
    "                print(\"A{}=\".format(str(cnt)), temp.size(), end=',')\n",
    "            else:\n",
    "                print(\"A{}=\".format(str(cnt)), temp.size())\n",
    "        elif which_model_parallel == 'row':\n",
    "            # Transpose back to A and take strided rows of A.\n",
    "            temp = master_weight.T\n",
    "            temp = temp[gp::world_size]\n",
    "            if cnt < world_size - 1:\n",
    "                print(\"A{}=\".format(str(cnt)), temp.size(), ',')\n",
    "            else:\n",
    "                print(\"A{}=\".format(str(cnt)), temp.size())\n",
    "        else:\n",
    "            raise ValueError(\"set which_model_parallel to 'col' or 'row'\")\n",
    "        cnt += 1\n",
    "        my_weight_list.append(temp)\n",
    "\n",
    "    print(\" ]\")\n",
    "    print(len(my_weight_list))\n",
    "    return my_weight_list\n",
    "\n",
    "\n",
    "def m_initialize_affine_weight_cpu(weight, output_size, input_size,\n",
    "                                   per_partition_size, partition_dim,\n",
    "                                   init_method, stride=1,\n",
    "                                   return_master_weight=False):\n",
    "    \"\"\"Initialize affine weight for model parallel.\n",
    "\n",
    "    Build the master weight on all processes and scatter\n",
    "    the relevant chunk.\"\"\"\n",
    "    params_dtype = torch.float\n",
    "    # Initialize master weight (as Megatron-LM does; otherwise it holds uninitialized memory).\n",
    "    master_weight = torch.empty(output_size, input_size,\n",
    "                                dtype=torch.float,\n",
    "                                requires_grad=False)\n",
    "    init_method(master_weight)\n",
    "    master_weight = master_weight.to(dtype=params_dtype)\n",
    "    # Split and copy\n",
    "    per_partition_per_stride_size = divide(per_partition_size, stride)\n",
    "    print(\"per_partition_per_stride_size \", per_partition_per_stride_size)\n",
    "    ######## tensor_model_parallel_gp below is hard-coded for tensor_model_parallel_size = 2, pipeline_model_parallel_size = 4 ########\n",
    "    ######## for another model parallel configuration, copy its tensor model-parallel groups into tensor_model_parallel_gp ########\n",
    "    tensor_model_parallel_gp = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]\n",
    "    my_weight_list = get_weight_list(master_weight, tensor_model_parallel_gp)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        torch.cat(my_weight_list, dim=partition_dim, out=weight)\n",
    "    if return_master_weight:\n",
    "        return master_weight\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "wired-graham",
   "metadata": {},
   "source": [
    "Peek inside the Column Parallel class: it partitions the weight matrix A into [A0, A1, A2, ..., An], where each partition Ai (i = 0, 1, 2, ..., n) is sliced **column-wise**. Each column-wise partition Ai is then sent to the corresponding GPU based on the grouped GPU affinity."
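   ]
  },
  {
   "cell_type": "markdown",
   "id": "column-parallel-check-intro",
   "metadata": {},
   "source": [
    "To see why column-wise slicing gives the right answer, the cell below is a minimal single-process sketch of the math behind Column Parallel. It is not Megatron-LM's implementation. It assumes an arbitrary 4-way split and small random float64 tensors. Each simulated GPU multiplies the full input X by its column shard A_i, and `torch.cat` along the last dimension stands in for the all-gather. If the split is correct, the gathered result equals Y = XA computed on a single device.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "column-parallel-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Assumed sizes for this sketch: X is 8 x 1024, A is 1024 x 512, split 4 ways.\n",
    "batch, input_size_demo, output_size_demo = 8, 1024, 512\n",
    "num_partitions = 4\n",
    "\n",
    "X = torch.randn(batch, input_size_demo, dtype=torch.float64)\n",
    "A = torch.randn(input_size_demo, output_size_demo, dtype=torch.float64)\n",
    "\n",
    "# Column parallelism: A = [A_0, ..., A_3], split along its second dimension.\n",
    "A_shards = torch.chunk(A, num_partitions, dim=1)\n",
    "\n",
    "# Every simulated GPU holds the full X and one shard, and computes Y_i = X A_i.\n",
    "Y_shards = [X @ A_i for A_i in A_shards]\n",
    "\n",
    "# Emulate the all-gather by concatenating the partial outputs along the last dimension.\n",
    "Y_parallel = torch.cat(Y_shards, dim=-1)\n",
    "Y_reference = X @ A\n",
    "\n",
    "print([tuple(y.shape) for y in Y_shards])\n",
    "print(torch.allclose(Y_parallel, Y_reference))"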
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "circular-hearing",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor_model_parallel_size = 2\n",
    "pipeline_model_parallel_size = 4\n",
    "input_size = 1024  # 1024 rows\n",
    "output_size = 512  # 512 columns\n",
    "which_model_parallel = 'col'\n",
    "print(\"this is how A is sliced column-wise ...\\n\")\n",
    "testCol = myColumnParallelLinear(input_size, output_size, bias=True, gather_output=True,\n",
    "                                 init_method=init.xavier_normal_, stride=1,\n",
    "                                 keep_master_weight_for_test=False,\n",
    "                                 skip_bias_add=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "instrumental-nickel",
   "metadata": {},
   "source": [
    "Below is the expected output:\n",
    "\n",
    "    this is how A is sliced column-wise ...\n",
    "\n",
    "    per_partition_per_stride_size  32\n",
    "    A = [\n",
    "    A0= torch.Size([1024, 32]),A1= torch.Size([1024, 32]),A2= torch.Size([1024, 32]),A3= torch.Size([1024, 32]),A4= torch.Size([1024, 32]),A5= torch.Size([1024, 32]),A6= torch.Size([1024, 32]),A7= torch.Size([1024, 32]),A8= torch.Size([1024, 32]),A9= torch.Size([1024, 32]),A10= torch.Size([1024, 32]),A11= torch.Size([1024, 32]),A12= torch.Size([1024, 32]),A13= torch.Size([1024, 32]),A14= torch.Size([1024, 32]),A15= torch.Size([1024, 32])\n",
    "     ]\n",
    "    16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "saved-quantity",
   "metadata": {},
   "outputs": [],
   "source": [
    "per_partition_per_stride_size = 32\n",
    "assert 16 * per_partition_per_stride_size == 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "complete-camel",
   "metadata": {},
   "outputs": [],
   "source": [
    "type(testCol)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "divided-mexican",
   "metadata": {},
   "source": [
    "Below is the expected output:\n",
    "\n",
    "    __main__.myColumnParallelLinear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "corporate-recipient",
   "metadata": {},
   "outputs": [],
   "source": [
    "testCol.input_size, testCol.output_size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sound-haven",
   "metadata": {},
   "source": [
    "Below is the expected output:\n",
    "\n",
    "    (1024, 512)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "coupled-stevens",
   "metadata": {},
   "source": [
    "----------------------------------------------------------------------\n",
    "## Megatron-LM's Row Parallel\n",
    "[RowParallel reference](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/layers.py#L294)\n",
    "![RowParallel](./Megatron-LM/pics/RowParallel.JPG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "prerequisite-simulation",
   "metadata": {},
   "outputs": [],
   "source": [
    "## The class below is modified from the original Megatron-LM repo to skip environment-variable initialization.\n",
    "\n",
    "class myRowParallelLinear(torch.nn.Module):\n",
    "    \"\"\"Linear layer with row parallelism.\n",
    "\n",
    "    The linear layer is defined as Y = XA + b. A is parallelized along\n",
    "    its first dimension and X along its second dimension as:\n",
    "               -   -\n",
    "              | A_1 |\n",
    "              | .   |\n",
    "          A = | .   |        X = [X_1, ..., X_p]\n",
    "              | .   |\n",
    "              | A_p |\n",
    "               -   -\n",
    "    Arguments:\n",
    "        input_size: first dimension of matrix A.\n",
    "        output_size: second dimension of matrix A.\n",
    "        bias: If true, add bias. Note that bias is not parallelized.\n",
    "        input_is_parallel: If true, we assume that the input is already\n",
    "                           split across the GPUs and we do not split\n",
    "                           again.\n",
    "        init_method: method to initialize weights. Note that bias is always set\n",
    "                     to zero.\n",
    "        stride: For the strided linear layers.\n",
    "        keep_master_weight_for_test: This was added for testing and should be\n",
    "                                     set to False. It returns the master weights\n",
    "                                     used for initialization.\n",
    "        skip_bias_add: This was added to enable performance optimizations where bias\n",
    "                       can be fused with other elementwise operations. We skip\n",
    "                       adding bias but instead return it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, output_size, bias=True,\n",
    "                 input_is_parallel=False,\n",
    "                 init_method=init.xavier_normal_, stride=1,\n",
    "                 keep_master_weight_for_test=False,\n",
    "                 skip_bias_add=False):\n",
    "        super(myRowParallelLinear, self).__init__()\n",
    "\n",
    "        # Keep input parameters\n",
    "        self.input_size = input_size\n",
    "        self.output_size = output_size\n",
    "        self.input_is_parallel = input_is_parallel\n",
    "        # Divide the weight matrix along the last dimension.\n",
    "        # Simplification: partition across all world_size GPUs, whereas Megatron-LM\n",
    "        # divides by the tensor-model-parallel world size.\n",
    "        self.input_size_per_partition = divide(input_size, world_size)\n",
    "        self.skip_bias_add = skip_bias_add\n",
    "        print(\"input_size_per_partition \", self.input_size_per_partition)\n",
    "\n",
    "        # Parameters.\n",
    "        # Note: torch.nn.functional.linear performs XA^T + b and as a result\n",
    "        # we allocate the transpose.\n",
    "        # Initialize weight.\n",
    "        use_cpu_initialization = True  # hard-coded to use the CPU\n",
    "        params_dtype = torch.float  # skips the need for Megatron args\n",
    "        if use_cpu_initialization:\n",
    "            self.weight = Parameter(torch.empty(self.output_size,\n",
    "                                                self.input_size_per_partition,\n",
    "                                                dtype=params_dtype))\n",
    "            self.master_weight = m_initialize_affine_weight_cpu(\n",
    "                self.weight, self.output_size, self.input_size,\n",
    "                self.input_size_per_partition, 1, init_method,\n",
    "                stride=stride, return_master_weight=keep_master_weight_for_test)\n",
    "        else:\n",
    "            self.weight = Parameter(torch.empty(\n",
    "                self.output_size, self.input_size_per_partition,\n",
    "                device=torch.cuda.current_device(), dtype=params_dtype))\n",
    "            layers._initialize_affine_weight_gpu(self.weight, init_method,\n",
    "                                                 partition_dim=1, stride=stride)\n",
    "        if bias:\n",
    "            if use_cpu_initialization:\n",
    "                self.bias = Parameter(torch.empty(self.output_size,\n",
    "                                                  dtype=params_dtype))\n",
    "            else:\n",
    "                self.bias = Parameter(torch.empty(\n",
    "                    self.output_size, device=torch.cuda.current_device(),\n",
    "                    dtype=params_dtype))\n",
    "            # Always initialize bias to zero.\n",
    "            with torch.no_grad():\n",
    "                self.bias.zero_()\n",
    "        else:\n",
    "            self.register_parameter('bias', None)\n",
    "\n",
    "    def forward(self, input_):\n",
    "        # Set up backprop all-reduce.\n",
    "        if self.input_is_parallel:\n",
    "            input_parallel = input_\n",
    "        else:\n",
    "            input_parallel = scatter_to_tensor_model_parallel_region(input_)\n",
    "        # Matrix multiply.\n",
    "        output_parallel = F.linear(input_parallel, self.weight)\n",
    "        # All-reduce across all the partitions.\n",
    "        output_ = reduce_from_tensor_model_parallel_region(output_parallel)\n",
    "        if not self.skip_bias_add:\n",
    "            output = output_ + self.bias if self.bias is not None else output_\n",
    "            output_bias = None\n",
    "        else:\n",
    "            output = output_\n",
    "            output_bias = self.bias\n",
    "        return output, output_bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "satisfactory-irish",
   "metadata": {},
   "source": [
    "Peek inside the Row Parallel class: it partitions the weight matrix A into [A0, A1, A2, ..., An], where each partition Ai (i = 0, 1, 2, ..., n) is sliced **row-wise**. Each row-wise partition Ai is then sent to the corresponding GPU based on the grouped GPU affinity."
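   ]
  },
  {
   "cell_type": "markdown",
   "id": "row-parallel-check-intro",
   "metadata": {},
   "source": [
    "To see why row-wise slicing gives the right answer, the cell below is a minimal single-process sketch of the math behind Row Parallel. It is not Megatron-LM's implementation. It assumes an arbitrary 4-way split and small random float64 tensors. A is split row-wise and X column-wise, so that XA = X_0 A_0 + ... + X_3 A_3. Each simulated GPU computes one full-size partial product X_i A_i, and summing the partials stands in for the all-reduce. The bias is added once, after the reduction, because it is not parallelized.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "row-parallel-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Assumed sizes for this sketch: X is 8 x 1024, A is 1024 x 512, split 4 ways.\n",
    "batch, input_size_demo, output_size_demo = 8, 1024, 512\n",
    "num_partitions = 4\n",
    "\n",
    "X = torch.randn(batch, input_size_demo, dtype=torch.float64)\n",
    "A = torch.randn(input_size_demo, output_size_demo, dtype=torch.float64)\n",
    "b = torch.randn(output_size_demo, dtype=torch.float64)\n",
    "\n",
    "# Row parallelism: split A along its first dimension and X along its second.\n",
    "A_shards = torch.chunk(A, num_partitions, dim=0)\n",
    "X_shards = torch.chunk(X, num_partitions, dim=1)\n",
    "\n",
    "# Every simulated GPU computes a full-size partial result X_i A_i.\n",
    "partials = [X_i @ A_i for X_i, A_i in zip(X_shards, A_shards)]\n",
    "\n",
    "# Emulate the all-reduce by summing the partials, then add the bias once.\n",
    "Y_parallel = torch.stack(partials).sum(dim=0) + b\n",
    "Y_reference = X @ A + b\n",
    "\n",
    "print([tuple(p.shape) for p in partials])\n",
    "print(torch.allclose(Y_parallel, Y_reference))"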
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "green-compiler",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_size = 1024  # first dimension of the matrix\n",
    "output_size = 512  # second dimension of the matrix\n",
    "print(\"this is how A is sliced row-wise ...\\n\")\n",
    "which_model_parallel = 'row'\n",
    "testRow = myRowParallelLinear(input_size, output_size, bias=True,\n",
    "                              input_is_parallel=False,\n",
    "                              init_method=init.xavier_normal_, stride=1,\n",
    "                              keep_master_weight_for_test=False,\n",
    "                              skip_bias_add=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "growing-standing",
   "metadata": {},
   "source": [
    "Below is the expected output:\n",
    "\n",
    "    this is how A is sliced row-wise ...\n",
    "\n",
    "    input_size_per_partition  64\n",
    "    per_partition_per_stride_size  64\n",
    "    A = [\n",
    "    A0= torch.Size([64, 512]) ,\n",
    "    A1= torch.Size([64, 512]) ,\n",
    "    A2= torch.Size([64, 512]) ,\n",
    "    A3= torch.Size([64, 512]) ,\n",
    "    A4= torch.Size([64, 512]) ,\n",
    "    A5= torch.Size([64, 512]) ,\n",
    "    A6= torch.Size([64, 512]) ,\n",
    "    A7= torch.Size([64, 512]) ,\n",
    "    A8= torch.Size([64, 512]) ,\n",
    "    A9= torch.Size([64, 512]) ,\n",
    "    A10= torch.Size([64, 512]) ,\n",
    "    A11= torch.Size([64, 512]) ,\n",
    "    A12= torch.Size([64, 512]) ,\n",
    "    A13= torch.Size([64, 512]) ,\n",
    "    A14= torch.Size([64, 512]) ,\n",
    "    A15= torch.Size([64, 512])\n",
    "     ]\n",
    "    16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "handmade-judge",
   "metadata": {},
   "outputs": [],
   "source": [
    "per_partition_per_stride_size = 64\n",
    "assert 16 * per_partition_per_stride_size == 1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "found-investing",
   "metadata": {},
   "outputs": [],
   "source": [
    "testRow.input_size, testRow.output_size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "prime-combine",
   "metadata": {},
   "source": [
    "Below is the expected output:\n",
    "\n",
    "    (1024, 512)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "disciplinary-western",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Links and Resources\n",
    "Don't forget to check out additional resources such as [Efficient Large-Scale Language Model Training on GPU Clusters](https://arxiv.org/pdf/2104.04473.pdf) and [Pushing Forward the Frontiers of Natural Language Processing](https://blogs.nvidia.com/blog/2021/09/16/nlp-frontiers-ai-hardware-summit/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "lucky-average",
   "metadata": {},
   "source": [
    "-----\n",
    "\n",
    "\n",
    "## Licensing \n",
    "\n",
    "This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}