{ "cells": [ { "cell_type": "markdown", "id": "alternate-collar", "metadata": {}, "source": [ "# \n", "\n", "# 5 Monitor GPT training performance with varying config\n", "---\n", "\n", "## Learning Objectives\n", "- **The goal of this lab is to monitor the performance of your training runs with different GPT training configurations **\n", " - motivation : why should we care ? \n", " \n", " Answer : bad config result in very low / inconsistent gpus utilizations which in turn, slow down training and therefore longer experiments per run, it's a lose-lose-lose situation on all sides.\n", " ![see example](./Megatron-LM/pics/naive_run.JPG)\n", " \n", " - example : naive run vs. improved run \n", " starts with multiGPUs --> multinode ( if we get at least 2 nodes per person / team ) \n", " - exercise : beat the record !\n", "\n", "it is possible to obtain more than 90% GPU utilizations overall with high tensorcore ops sustained throughout forward and backward training throughout all gpus used in training. \n" ] }, { "cell_type": "code", "execution_count": null, "id": "fifty-swimming", "metadata": {}, "outputs": [], "source": [ "!rm -fr ./Megatron-LM/sv_ckpt/*" ] }, { "cell_type": "markdown", "id": "copyrighted-belarus", "metadata": {}, "source": [ "## Let's verify the environment is ready " ] }, { "cell_type": "code", "execution_count": null, "id": "chronic-bradley", "metadata": {}, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": null, "id": "personalized-walker", "metadata": {}, "outputs": [], "source": [ "!nvidia-smi nvlink --status " ] }, { "cell_type": "markdown", "id": "minimal-extreme", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "id": "prostate-trouble", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 2, "id": "industrial-index", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting data...\n", "using world size: 8, data-parallel-size: 8, tensor-model-parallel size: 1, pipeline-model-parallel size: 1 \n", "using torch.float32 for parameters ...\n", "------------------------ arguments ------------------------\n", " accumulate_allreduce_grads_in_fp32 .............. False\n", " adam_beta1 ...................................... 0.9\n", " adam_beta2 ...................................... 0.999\n", " adam_eps ........................................ 1e-08\n", " adlr_autoresume ................................. False\n", " adlr_autoresume_interval ........................ 1000\n", " apply_query_key_layer_scaling ................... True\n", " apply_residual_connection_post_layernorm ........ False\n", " attention_dropout ............................... 0.1\n", " attention_softmax_in_fp32 ....................... False\n", " bert_binary_head ................................ True\n", " bert_load ....................................... None\n", " bf16 ............................................ False\n", " bias_dropout_fusion ............................. True\n", " bias_gelu_fusion ................................ True\n", " biencoder_projection_dim ........................ 0\n", " biencoder_shared_query_context_model ............ False\n", " block_data_path ................................. None\n", " checkpoint_activations .......................... True\n", " checkpoint_num_layers ........................... 1\n", " clip_grad ....................................... 1.0\n", " consumed_train_samples .......................... 
0\n", " consumed_valid_samples .......................... 0\n", " data_impl ....................................... mmap\n", " data_parallel_size .............................. 8\n", " data_path ....................................... ['../dataset/EN/NVblogs_text_document']\n", " dataloader_type ................................. single\n", " DDP_impl ........................................ local\n", " decoder_seq_length .............................. None\n", " distribute_checkpointed_activations ............. False\n", " distributed_backend ............................. nccl\n", " embedding_path .................................. None\n", " encoder_seq_length .............................. 512\n", " eod_mask_loss ................................... False\n", " eval_interval ................................... 100\n", " eval_iters ...................................... 10\n", " evidence_data_path .............................. None\n", " exit_duration_in_mins ........................... None\n", " exit_interval ................................... None\n", " ffn_hidden_size ................................. 4096\n", " finetune ........................................ False\n", " fp16 ............................................ False\n", " fp16_lm_cross_entropy ........................... False\n", " fp32_residual_connection ........................ False\n", " global_batch_size ............................... 8\n", " hidden_dropout .................................. 0.1\n", " hidden_size ..................................... 1024\n", " hysteresis ...................................... 2\n", " ict_head_size ................................... None\n", " ict_load ........................................ None\n", " img_dim ......................................... 224\n", " indexer_batch_size .............................. 128\n", " indexer_log_interval ............................ 1000\n", " init_method_std ................................. 0.02\n", " init_method_xavier_uniform ...................... False\n", " initial_loss_scale .............................. 4294967296\n", " kv_channels ..................................... 64\n", " layernorm_epsilon ............................... 1e-05\n", " lazy_mpu_init ................................... None\n", " load ............................................ ./Megatron-LM/sv_ckpt/\n", " local_rank ...................................... 0\n", " log_batch_size_to_tensorboard ................... False\n", " log_interval .................................... 10\n", " log_learning_rate_to_tensorboard ................ True\n", " log_loss_scale_to_tensorboard ................... True\n", " log_num_zeros_in_grad ........................... False\n", " log_params_norm ................................. False\n", " log_timers_to_tensorboard ....................... False\n", " log_validation_ppl_to_tensorboard ............... False\n", " loss_scale ...................................... None\n", " loss_scale_window ............................... 1000\n", " lr .............................................. 0.00015\n", " lr_decay_iters .................................. None\n", " lr_decay_samples ................................ None\n", " lr_decay_style .................................. cosine\n", " lr_warmup_fraction .............................. 0.01\n", " lr_warmup_iters ................................. 0\n", " lr_warmup_samples ............................... 0\n", " make_vocab_size_divisible_by .................... 
128\n", " mask_prob ....................................... 0.15\n", " masked_softmax_fusion ........................... True\n", " max_position_embeddings ......................... 512\n", " merge_file ...................................... ../dataset/EN/50k/gpt2-merges.txt\n", " micro_batch_size ................................ 1\n", " min_loss_scale .................................. 1.0\n", " min_lr .......................................... 1e-05\n", " mmap_warmup ..................................... False\n", " no_load_optim ................................... None\n", " no_load_rng ..................................... None\n", " no_save_optim ................................... None\n", " no_save_rng ..................................... None\n", " num_attention_heads ............................. 16\n", " num_channels .................................... 3\n", " num_classes ..................................... 1000\n", " num_layers ...................................... 16\n", " num_layers_per_virtual_pipeline_stage ........... None\n", " num_workers ..................................... 2\n", " onnx_safe ....................................... None\n", " openai_gelu ..................................... False\n", " optimizer ....................................... adam\n", " override_lr_scheduler ........................... False\n", " params_dtype .................................... torch.float32\n", " patch_dim ....................................... 16\n", " pipeline_model_parallel_size .................... 1\n", " query_in_block_prob ............................. 0.1\n", " rampup_batch_size ............................... None\n", " rank ............................................ 0\n", " reset_attention_mask ............................ False\n", " reset_position_ids .............................. False\n", " retriever_report_topk_accuracies ................ []\n", " retriever_score_scaling ......................... False\n", " retriever_seq_length ............................ 256\n", " sample_rate ..................................... 1.0\n", " save ............................................ ./Megatron-LM/sv_ckpt/\n", " save_interval ................................... 100\n", " scatter_gather_tensors_in_pipeline .............. True\n", " seed ............................................ 1234\n", " seq_length ...................................... 512\n", " sgd_momentum .................................... 0.9\n", " short_seq_prob .................................. 0.1\n", " split ........................................... 949,50,1\n", " tensor_model_parallel_size ...................... 1\n", " tensorboard_dir ................................. None\n", " tensorboard_log_interval ........................ 1\n", " tensorboard_queue_size .......................... 1000\n", " titles_data_path ................................ None\n", " tokenizer_type .................................. GPT2BPETokenizer\n", " train_iters ..................................... None\n", " train_samples ................................... 100\n", " use_checkpoint_lr_scheduler ..................... False\n", " use_contiguous_buffers_in_ddp ................... False\n", " use_cpu_initialization .......................... None\n", " use_one_sent_docs ............................... False\n", " virtual_pipeline_model_parallel_size ............ None\n", " vocab_extra_ids ................................. 0\n", " vocab_file ...................................... 
../dataset/EN/50k/gpt2-vocab.json\n", " weight_decay .................................... 0.01\n", " world_size ...................................... 8\n", "-------------------- end of arguments ---------------------\n", "setting number of micro-batches to constant 1\n", "> building GPT2BPETokenizer tokenizer ...\n", " > padded vocab (size: 50257) with 47 dummy tokens (new size: 50304)\n", "> initializing torch distributed ...\n", "> initializing tensor model parallel with size 1\n", "> initializing pipeline model parallel with size 1\n", "> setting random seeds to 1234 ...\n", "> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234\n", "> compiling dataset index builder ...\n", "make: Entering directory '/home/zcharpy/bootcamp/jupyter_notebook/Megatron-LM/megatron/data'\n", "make: Nothing to be done for 'default'.\n", "make: Leaving directory '/home/zcharpy/bootcamp/jupyter_notebook/Megatron-LM/megatron/data'\n", ">>> done with dataset index builder. Compilation time: 0.573 seconds\n", "WARNING: constraints for invoking optimized fused softmax kernel are not met. We default back to unfused kernel invocations.\n", "> compiling and loading fused kernels ...\n", "Detected CUDA files, patching ldflags\n", "Emitting ninja build file /home/zcharpy/bootcamp/jupyter_notebook/Megatron-LM/megatron/fused_kernels/build/build.ninja...\n", "Building extension module scaled_upper_triang_masked_softmax_cuda...\n", "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", "ninja: no work to do.\n", "Loading extension module scaled_upper_triang_masked_softmax_cuda...\n", "Detected CUDA files, patching ldflags\n", "Emitting ninja build file /home/zcharpy/bootcamp/jupyter_notebook/Megatron-LM/megatron/fused_kernels/build/build.ninja...\n", "Building extension module scaled_masked_softmax_cuda...\n", "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", "ninja: no work to do.\n", "Loading extension module scaled_masked_softmax_cuda...\n", "Detected CUDA files, patching ldflags\n", "Emitting ninja build file /home/zcharpy/bootcamp/jupyter_notebook/Megatron-LM/megatron/fused_kernels/build/build.ninja...\n", "Building extension module fused_mix_prec_layer_norm_cuda...\n", "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", "ninja: no work to do.\n", "Loading extension module fused_mix_prec_layer_norm_cuda...\n", ">>> done with compiling and loading fused kernels. 
Compilation time: 31.516 seconds\n", "time to initialize megatron (seconds): 13.548\n", "[after megatron is initialized] datetime: 2021-08-26 00:28:13 \n", "building GPT model ...\n", " > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 253577216\n", "setting training iterations to 12\n", "> learning rate decay style: cosine\n", "WARNING: could not find the metadata file ./Megatron-LM/sv_ckpt/latest_checkpointed_iteration.txt \n", " will not load any checkpoints and will start from random\n", "time (ms) | load-checkpoint: 30.87\n", "[after model, optimizer, and learning rate scheduler are built] datetime: 2021-08-26 00:28:14 \n", "> building train, validation, and test datasets ...\n", " > datasets target sizes (minimum size):\n", " train: 100\n", " validation: 80\n", " test: 80\n", "> building train, validation, and test datasets for GPT ...\n", " > building dataset index ...\n", " reading sizes...\n", " reading pointers...\n", " reading document index...\n", " creating numpy buffer of mmap...\n", " creating memory view of numpy buffer...\n", " > finished creating indexed dataset in 0.003097 seconds\n", " number of documents: 74\n", " > dataset split:\n", " train:\n", " document indices in [0, 70) total of 70 documents\n", " validation:\n", " document indices in [70, 74) total of 4 documents\n", " test:\n", " document indices in [74, 74) total of 0 documents\n", " > loading doc-idx mapping from ../dataset/EN/NVblogs_text_document_train_indexmap_100ns_512sl_1234s_doc_idx.npy\n", " > loading sample-idx mapping from ../dataset/EN/NVblogs_text_document_train_indexmap_100ns_512sl_1234s_sample_idx.npy\n", " > loading shuffle-idx mapping from ../dataset/EN/NVblogs_text_document_train_indexmap_100ns_512sl_1234s_shuffle_idx.npy\n", " loaded indexed file in 0.018 seconds\n", " total number of samples: 142\n", " total number of epochs: 1\n", " > loading doc-idx mapping from ../dataset/EN/NVblogs_text_document_valid_indexmap_80ns_512sl_1234s_doc_idx.npy\n", " > loading sample-idx mapping from ../dataset/EN/NVblogs_text_document_valid_indexmap_80ns_512sl_1234s_sample_idx.npy\n", " > loading shuffle-idx mapping from ../dataset/EN/NVblogs_text_document_valid_indexmap_80ns_512sl_1234s_shuffle_idx.npy\n", " loaded indexed file in 0.022 seconds\n", " total number of samples: 86\n", " total number of epochs: 11\n", "> finished creating GPT datasets ...\n", "[after dataloaders are built] datetime: 2021-08-26 00:28:24 \n", "done with setup ...\n", "training ...\n", "time (ms) | model-and-optimizer-setup: 548.83 | train/valid/test-data-iterators-setup: 10068.67\n", "[before the start of training step] datetime: 2021-08-26 00:28:24 \n", " iteration 10/ 12 | consumed samples: 80 | elapsed time per iteration (ms): 2141.9 | learning rate: 2.363E-05 | global batch size: 8 | lm loss: 9.601698E+00 | loss scale: 1.0 | grad norm: 1.856 | number of skipped iterations: 0 | number of nan iterations: 0 |\n", "time (ms) | forward-compute: 1252.78 | backward-compute: 643.07 | backward-params-all-reduce: 146.14 | backward-embedding-all-reduce: 0.06 | optimizer: 94.92 | batch-generator: 12.73\n", "[Rank 0] (after 10 iterations) memory (MB) | allocated: 3869.28369140625 | max allocated: 5229.60595703125 | reserved: 7306.0 | max reserved: 7306.0\n", "[after training is done] datetime: 2021-08-26 00:28:47 \n", "------------------------------------------------------------------------------------------------------------------saving checkpoint at iteration 12 to ./Megatron-LM/sv_ckpt/\n", "\n", " validation 
loss at the end of training for val data | lm loss value: 8.891883E+00 | lm loss PPL: 7.272700E+03 | \n", "------------------------------------------------------------------------------------------------------------------\n", " successfully saved checkpoint at iteration 12 to ./Megatron-LM/sv_ckpt/\n", "*****************************************\n", "Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", "*****************************************\n", "Processing events...\n", "Capturing symbol files...\n", "Saving temporary \"/tmp/nsys-report-5e6b-aa12-9711-df33.qdstrm\" file to disk...\n", "Creating final output files...\n", "\n", "Processing [==============================================================100%]\n", "Saved report file to \"/tmp/nsys-report-5e6b-aa12-9711-df33.qdrep\"\n", "Exporting 1665820 events: [===============================================100%]\n", "\n", "Exported successfully to\n", "/tmp/nsys-report-5e6b-aa12-9711-df33.sqlite\n", "\n", "\n", "CUDA API Statistics:\n", "\n", " Time(%) Total Time (ns) Num Calls Average Minimum Maximum Name \n", " ------- --------------- --------- ----------- ------- ----------- -------------------------------\n", " 56.2 285320355235 1696 168231341.5 3530 29703799378 cudaDeviceSynchronize \n", " 29.6 149955233137 4392 34142812.6 4864 17084387260 cudaMalloc \n", " 7.9 39849705340 2960 13462738.3 1275 2244485039 cudaFree \n", " 1.7 8434557988 132848 63490.3 4159 108060983 cudaLaunchKernel \n", " 1.6 8056133957 1136 7091667.2 5385 73711234 cudaHostAlloc \n", " 1.4 7314617708 23844 306769.7 6186 758561184 cudaMemcpyAsync \n", " 0.4 1978894053 720 2748464.0 352582 35632270 cudaIpcOpenMemHandle \n", " 0.3 1558283230 1064 1464551.9 5957 26847462 cudaFreeHost \n", " 0.3 1519497814 720 2110413.6 240178 21756247 cudaIpcCloseMemHandle \n", " 0.2 869040030 15104 57537.1 7662 32041419 cuLaunchKernel \n", " 0.1 760874918 14944 50915.1 5032 103123066 cudaMemsetAsync \n", " 0.1 416824863 63658 6547.9 361 79176267 cudaStreamIsCapturing_v10000 \n", " 0.1 289499706 16483 17563.5 1293 40910490 cudaEventQuery \n", " 0.1 283699184 21864 12975.6 643 60138070 cudaEventRecord \n", " 0.0 163990222 44832 3657.9 279 24748108 cudaStreamGetCaptureInfo_v10010\n", " 0.0 149197647 512 291401.7 2613 37658462 cudaStreamCreateWithPriority \n", " 0.0 138131763 4504 30668.7 523 32220166 cudaEventDestroy \n", " 0.0 113211871 2732 41439.2 2572 30272284 cudaStreamSynchronize \n", " 0.0 75811639 2848 26619.3 3170 29446146 cudaMemset \n", " 0.0 68895475 40 1722386.9 101401 32335797 cuModuleLoadData \n", " 0.0 36296853 1824 19899.6 6299 6783812 cudaMemcpy \n", " 0.0 13397443 40 334936.1 21288 4248804 cuModuleUnload \n", " 0.0 12237981 4520 2707.5 381 2838236 cudaEventCreateWithFlags \n", " 0.0 11913924 2784 4279.4 1276 3291031 cudaStreamWaitEvent \n", " 0.0 100758 24 4198.3 1730 10252 cuInit \n", "\n", "\n", "\n", "CUDA Kernel Statistics:\n", "\n", " Time(%) Total Time (ns) Instances Average Minimum Maximum Name \n", " ------- --------------- --------- ------------ ------- ----------- ----------------------------------------------------------------------------------------------------\n", " 88.2 260838216539 88 2964070642.5 9664 29673260794 ncclKernel_AllReduce_RING_LL_Sum_uint8_t(ncclWorkElem) \n", " 6.4 18834798439 280 67267137.3 13952 423636076 ncclKernel_AllReduce_RING_LL_Sum_float(ncclWorkElem) \n", " 1.3 3737595881 
13056 286274.2 100127 445790 volta_sgemm_128x32_tn \n", " 0.6 1922114218 6240 308031.1 88128 4093366 volta_sgemm_128x32_nt \n", " 0.6 1914679497 4528 422853.2 268063 4114387 volta_sgemm_128x64_tn \n", " 0.6 1628497447 6144 265054.9 91936 432223 volta_sgemm_128x32_nn \n", " 0.3 1019927165 16 63745447.8 12160 185146385 ncclKernel_AllReduce_RING_LL_Sum_int64_t(ncclWorkElem) \n", " 0.3 911914712 1632 558771.3 5824 807487 void multi_tensor_apply_kernel, AdamFunctor, float, float, float, floa…\n", " 0.2 624933625 5888 106136.8 96128 125600 volta_sgemm_64x64_nn \n", " 0.2 516397176 9608 53746.6 2495 2498036 void at::native::vectorized_elementwise_kernel<4, at::native::MulScalarFunctor, at::d…\n", " 0.1 370199942 96 3856249.4 3595413 4220668 volta_sgemm_64x32_sliced1x4_nn \n", " 0.1 365762037 5888 62119.9 53312 81952 volta_sgemm_64x64_tn \n", " 0.1 317760917 3072 103437.8 94048 120768 volta_sgemm_64x64_nt \n", " 0.1 291143645 5888 49447.0 45408 58304 void at::native::unrolled_elementwise_kernel(float*,…\n", " 0.1 251234285 1440 174468.3 3585 228607 void multi_tensor_apply_kernel, ScaleFunctor, float>(int, int v…\n", " 0.1 211542704 4352 48608.2 45728 53376 void (anonymous namespace)::softmax_warp_forward(float*, float const…\n", " 0.1 165985735 3168 52394.5 7648 57663 void at::native::(anonymous namespace)::fused_dropout_kernel_vec(…\n", " 0.1 154413148 7088 21785.2 7424 259423 void at::native::unrolled_elementwise_kernel, at::detail::Array, L2NormFunctor, float*, float*, bool, i…\n", " 0.0 118285804 6352 18621.8 2560 733214 void at::native::vectorized_elementwise_kernel<4, at::native::AddFunctor, at::detail::Array<…\n", " 0.0 105684444 4352 24284.1 21791 37184 kernel_1 \n", " 0.0 98824724 1904 51903.7 2496 70079 void at::native::vectorized_elementwise_kernel<4, at::native::MulFunctor, at::detail::Array<…\n", " 0.0 97740010 1536 63632.8 62784 68160 void (anonymous namespace)::softmax_warp_backward(float*, float cons…\n", " 0.0 82467209 8880 9286.8 6656 16928 void cuApplyLayerNorm(float*, float*, float*, float const*, int, int, float, f…\n", " 0.0 76246947 1632 46719.9 5920 52928 void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous namespace)::masked_scale_k…\n", " 0.0 66689595 5952 11204.6 8864 18848 kernel_2 \n", " 0.0 65696173 4560 14407.1 2335 344991 void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor, at::detail::Array…\n", " 0.0 54335809 1536 35374.9 33120 40640 kernel_4 \n", " 0.0 51338187 4352 11796.5 8576 16288 void at::native::unrolled_elementwise_kernel, at::detail::Array(float*,…\n", " 0.0 33892087 3168 10698.3 8480 18304 void cuComputePartGradGammaBeta(float const*, float const*, int, int, float co…\n", " 0.0 29677485 3168 9367.9 8032 18464 void cuComputeGradInput(float const*, float const*, int, int, float const*, fl…\n", " 0.0 28793796 3072 9373.0 6816 16351 kernel_3 \n", " 0.0 25638427 176 145672.9 141248 151744 void at::native::reduce_kernel<512, 1, at::native::ReduceOp, unsig…\n", " 0.0 24934651 192 129868.0 6208 254687 void at::native::unrolled_elementwise_kernel, at::detail::Array(float const*, float const*, int, int, int, float*, float*)\n", " 0.0 4617682 192 24050.4 18080 34848 void at::native::(anonymous namespace)::embedding_backward_feature_kernel(long*…\n", " 0.0 4383059 352 12451.9 11072 14975 void at::native::(anonymous namespace)::indexSelectLargeIndex::Policy600…\n", " 0.0 1350339 352 3836.2 2911 5216 void cub::DeviceSelectSweepKernel, cub…\n", " 0.0 1254812 176 7129.6 6464 9217 void 
at::native::triu_tril_kernel(at::cuda::detail::TensorInfo, at::…\n", " 0.0 1248508 448 2786.8 2399 3840 void (anonymous namespace)::elementwise_kernel_with_index, at::detail::Array<…\n", " 0.0 942900 352 2678.7 2336 3552 void cub::DeviceCompactInitKernel, int*>(cub::ScanTileState, at::detail::A…\n", " 0.0 502590 176 2855.6 2655 3488 void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor, at::detail::Array<…\n", " 0.0 383140 96 3991.0 3744 4480 cleanup(float*, float*, float*, float*, bool, int) \n", " 0.0 319140 96 3324.4 2912 16032 void at::native::vectorized_elementwise_kernel<4, at::native::BUnaryFunctor" ] }, { "cell_type": "code", "execution_count": null, "id": "requested-clause", "metadata": {}, "outputs": [], "source": [ "!bash ./Megatron-LM/dlprof_2nd_run.sh" ] }, { "cell_type": "code", "execution_count": null, "id": "written-trace", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "sunrise-borough", "metadata": {}, "source": [ "----------------\n", "\n", "## **Challenge** - the best profile\n", "\n", "Using the exact same compute limitations (i.e., the number of GPUs you currently have):\n", "\n", "Task: modify the [profiling bash script](./Megatron-LM/dlprof_2nd_run.sh), jump back to the rerun cell above, and rerun it. \n", "Monitor the training runs until you reach an overall GPU utilization above 80% during **training** runs. \n", "\n", "```\n", " TENSOR_MP_SIZE=1\n", " PIPELINE_MP_SIZE=1\n", "\n", " #GPT Config \n", " LAYERS= \n", " HIDDEN_SIZE=\n", " ATTN_HEADS=\n", " MICRO_BZ=\n", " GB_BZ=\n", " SEQ_LEN=\n", " MAX_POS_EM=\n", "```\n", "\n", "See the hints below for one illustrative configuration and a quick way to watch GPU utilization while you iterate.\n" ] }
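, { "cell_type": "markdown", "id": "challenge-hints", "metadata": {}, "source": [ "### Hints\n", "\n", "While a training run is active, you can sample per-GPU utilization to check whether your configuration keeps all GPUs busy. The cell below is a minimal sketch using `nvidia-smi dmon` (it prints ten samples of the utilization counters, one per second); run it alongside an active training run.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "challenge-monitor", "metadata": {}, "outputs": [], "source": [ "# Sample per-GPU utilization (sm %) ten times while a training run is active.\n", "!nvidia-smi dmon -s u -c 10" ] }, { "cell_type": "markdown", "id": "challenge-config-sketch", "metadata": {}, "source": [ "One illustrative way to fill in the template, assuming a single node with 8 GPUs (a starting point to iterate from, not a known-best answer):\n", "\n", "```\n", " TENSOR_MP_SIZE=2\n", " PIPELINE_MP_SIZE=1\n", "\n", " #GPT Config \n", " LAYERS=16\n", " HIDDEN_SIZE=1024\n", " ATTN_HEADS=16\n", " MICRO_BZ=8\n", " GB_BZ=64\n", " SEQ_LEN=512\n", " MAX_POS_EM=512\n", "```\n", "\n", "Whatever values you choose, keep them consistent: `HIDDEN_SIZE` must be divisible by `ATTN_HEADS`, `SEQ_LEN` cannot exceed `MAX_POS_EM`, and `GB_BZ` must be divisible by `MICRO_BZ` times the data-parallel size (the number of GPUs divided by `TENSOR_MP_SIZE` x `PIPELINE_MP_SIZE`).\n" ] }, { "cell_type": "markdown", "id": "streaming-artist", "metadata": {}, "source": [ "-----\n", "\n", "\n", "## Licensing \n", "\n", "This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }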