Tensor Parallel 구현 비교

TL;DR

이 글은 Tensor Parallel 이론편의 후속으로, 6개 주요 프레임워크의 실제 TP 구현을 비교합니다. 학습 프레임워크(Megatron-LM, nanotron, DeepSpeed, torchtitan)와 추론 프레임워크(vLLM, SGLang)는 동일한 이론적 기반 위에서 서로 다른 최적화 철학을 보여줍니다. Megatron-LM의 f/g 켤레 연산자 패러다임은 거의 모든 프레임워크에 영향을 주었으나, 각자의 요구사항에 맞게 변형되었습니다. 핵심 차이점: 학습 프레임워크는 그래디언트 통신 최적화에, 추론 프레임워크는 weight loading과 레이턴시 최소화에 집중합니다.

이 글은 구현 세부사항뿐 아니라 각 프레임워크의 API 사용법 (CLI, Python, Config)도 함께 다루어, 실제 적용을 위한 실용적 가이드를 제공합니다.


Related Work

이론적 기반:

프레임워크 문서:

  • Megatron-LM - NVIDIA의 대규모 Transformer 학습
  • nanotron - HuggingFace의 분산 학습 라이브러리
  • DeepSpeed - Microsoft의 분산 학습 최적화
  • torchtitan - Meta의 PyTorch-native 학습 플랫폼
  • vLLM - 고성능 LLM 서빙 엔진
  • SGLang - 구조화된 출력 지원 LLM 서빙

1. 서론: 왜 구현을 비교하는가?

이전 글에서 Megatron-LM이 제안한 Tensor Parallelism의 이론적 기반을 살펴보았습니다. 핵심 아이디어는 단순합니다:

  1. ColumnParallelLinear: 출력 차원을 분할하고, backward에서 all-reduce
  2. RowParallelLinear: 입력 차원을 분할하고, forward에서 all-reduce
  3. f/g 켤레 연산자: forward와 backward에서 상호 보완적인 통신 패턴

그러나 실제 구현에서는 다양한 트레이드오프가 존재합니다:

고려 사항 학습 최적화 추론 최적화
주요 병목 그래디언트 통신 weight loading, 레이턴시
메모리 관심 활성화 + optimizer 상태 KV cache
배치 크기 크게 (수천) 작게 (수십~수백)
통신 패턴 비동기 오버랩 중요 동기 단순성 선호

이 글에서는 6개 프레임워크의 실제 코드를 분석하여, 이론이 어떻게 다양한 형태로 구현되는지 살펴봅니다.


2. 학습 프레임워크

2.1 Megatron-LM: 원조의 정석

핵심 파일: megatron/core/tensor_parallel/

Megatron-LM은 TP의 원조답게 가장 정교한 구현을 제공합니다. 핵심은 f/g 켤레 연산자입니다.

Quick Start

CLI:

# 8-way Tensor Parallelism
torchrun --nproc_per_node=8 pretrain_gpt.py \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 1 \
    --num-layers 32 \
    --hidden-size 4096 \
    --num-attention-heads 32

Python API:

from megatron.core import parallel_state

# 분산 환경 초기화
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=8,
    pipeline_model_parallel_size=1,
    virtual_pipeline_model_parallel_size=None,
)

# TP 그룹 조회
tp_group = parallel_state.get_tensor_model_parallel_group()
tp_rank = parallel_state.get_tensor_model_parallel_rank()
tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

주의사항:

  • CUDA_DEVICE_MAX_CONNECTIONS=1 환경 변수 필수 (비동기 통신 순서 보장)
  • --sequence-parallel 플래그로 Sequence Parallel 활성화 가능
  • TP 크기는 attention head 수의 약수여야 함

핵심 연산자 4종 세트

Megatron-LM은 4가지 핵심 autograd 연산자를 정의합니다:

연산자 Forward Backward 용도
_CopyToModelParallelRegion identity all-reduce ColumnParallel 입력
_ReduceFromModelParallelRegion all-reduce identity RowParallel 출력
_ScatterToModelParallelRegion split(last) gather(last) hidden dim 분할
_GatherFromModelParallelRegion gather(last) split(last) hidden dim 수집

파일: megatron/core/tensor_parallel/mappings.py (lines 197-273)

f 연산자: Copy to TP Region

# megatron/core/tensor_parallel/mappings.py:197-214
class _CopyToModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, group):
        """Forward: identity (통신 없음)"""
        ctx.group = group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        """Backward: all-reduce gradients"""
        return _reduce(grad_output, ctx.group), None

용도: ColumnParallelLinear 시작 부분에서 사용. 입력 $X$는 모든 랭크에 복제되어 있으므로 forward에서는 그대로 전달하고, backward에서 각 랭크의 부분 그래디언트를 합산합니다.

g 연산자: Reduce from TP Region

# megatron/core/tensor_parallel/mappings.py:217-233
class _ReduceFromModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, group):
        """Forward: all-reduce across TP ranks"""
        return _reduce(input_, group)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward: identity (통신 없음)"""
        return grad_output, None

용도: RowParallelLinear 끝 부분에서 사용. 각 랭크가 계산한 $Y_i = X_i A_i$를 합산하여 최종 출력 $Y = \sum_i Y_i$를 생성합니다.

Scatter/Gather 연산자

Scatter와 Gather는 hidden dimension을 분할/수집할 때 사용됩니다:

# megatron/core/tensor_parallel/mappings.py:236-273
class _ScatterToModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, group):
        """Forward: split along last dim"""
        ctx.group = group
        return _split_along_last_dim(input_, group)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward: gather along last dim"""
        return _gather_along_last_dim(grad_output, ctx.group), None

class _GatherFromModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, group):
        """Forward: gather along last dim"""
        ctx.group = group
        return _gather_along_last_dim(input_, group)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward: split along last dim"""
        return _split_along_last_dim(grad_output, ctx.group), None

비동기 통신-계산 오버랩

Megatron-LM의 핵심 혁신 중 하나는 통신과 계산의 오버랩입니다:

# megatron/core/tensor_parallel/layers.py:494-627
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        tp_group = ctx.tp_group

        # 1. 입력 그래디언트 계산
        grad_input = grad_output.matmul(weight)

        # 2. 비동기 all-reduce 시작 (통신)
        if ctx.allreduce_dgrad:
            handle = torch.distributed.all_reduce(
                grad_input, group=tp_group, async_op=True
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # all-reduce is scheduled before the weight gradient computation

        # 3. 가중치 그래디언트 계산 (계산) - 통신과 동시에!
        if ctx.gradient_accumulation_fusion:
            # CUDA 커널로 main_grad에 직접 누적 → 중간 텐서 할당 제거
            if weight.main_grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                    total_input, grad_output, weight.main_grad
                )
            # ...
        else:
            grad_weight = grad_output.t().matmul(total_input)

        # 4. all-reduce 완료 대기
        if ctx.allreduce_dgrad:
            handle.wait()

        return grad_input, grad_weight, ...

핵심 통찰:

  • async_op=True로 all-reduce를 비동기로 시작
  • CUDA_DEVICE_MAX_CONNECTIONS=1 환경 변수로 CUDA 스트림 스케줄링 순서 보장
  • fused_weight_gradient_mlp_cuda 커널로 그래디언트를 weight.main_grad에 직접 누적 → 중간 텐서 할당 제거

VocabUtility 패딩 전략

어휘 크기를 TP 그룹 간 균등하게 분할합니다:

# megatron/core/tensor_parallel/utils.py:97-121
class VocabUtility:
    @staticmethod
    def vocab_range_from_global_vocab_size(
        global_vocab_size: int, rank: int, world_size: int
    ) -> Sequence[int]:
        """Vocab range from global vocab size."""
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
        )

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank, world_size: int
    ) -> Sequence[int]:
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

왜 필요한가: TP 그룹 간 균등한 어휘 분할 + all-gather/reduce-scatter 효율성

Vocab Parallel Cross-Entropy

임베딩 레이어의 출력 통신을 최소화하기 위해, cross-entropy 손실을 병렬로 계산합니다:

# megatron/core/tensor_parallel/cross_entropy.py
class _VocabParallelCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):
        # 1. logits_max: 수치 안정성을 위한 all-reduce (MAX)
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_max, op=ReduceOp.MAX)

        # 2. predicted_logits: 타겟 토큰의 logit만 all-reduce (SUM)
        # (각 랭크는 자신의 vocab 파티션에 해당하는 타겟만 처리)

        # 3. sum_exp_logits: 분할 함수 계산을 위한 all-reduce (SUM)

        # 최종: loss = log(sum_exp) - predicted_logits

통신량: $O(b \times s)$ (vocab 차원 제거) vs $O(b \times s \times v)$ (전체 logits 수집)


2.2 nanotron: 모듈러 설계

핵심 파일: src/nanotron/parallel/tensor_parallel/

nanotron은 HuggingFace에서 개발한 분산 학습 라이브러리로, Megatron-LM의 설계를 더 모듈러하게 재구성했습니다.

Quick Start

CLI:

# YAML 설정 파일 기반 실행
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 \
    run_train.py --config-file configs/llama_tp8.yaml

YAML Config:

# configs/llama_tp8.yaml
parallelism:
  dp: 1
  pp: 1
  tp: 8                    # Tensor Parallelism 크기
  pp_engine: 1f1b
  tp_mode: ALL_REDUCE      # 또는 REDUCE_SCATTER
  tp_linear_async_communication: true
  recompute_layer: false

Python API:

from nanotron.config import ParallelismArgs
from nanotron.parallel.context import ParallelContext

# 병렬화 설정
parallelism = ParallelismArgs(
    dp=1,
    pp=1,
    tp=8,
    tp_mode="ALL_REDUCE",  # 또는 "REDUCE_SCATTER"
    tp_linear_async_communication=True,
)

# ParallelContext 초기화
parallel_context = ParallelContext(
    tensor_parallel_size=parallelism.tp,
    pipeline_parallel_size=parallelism.pp,
    data_parallel_size=parallelism.dp,
)

주의사항:

  • tp_mode: ALL_REDUCE(표준 TP) vs REDUCE_SCATTER(SP 결합용)
  • tp_recompute_allgather=True로 메모리 절약 가능 (계산 증가 트레이드오프)
  • YAML 설정이 권장되는 주요 인터페이스

두 가지 TP 모드: ALL_REDUCE vs REDUCE_SCATTER

nanotron의 가장 큰 특징은 명시적인 2가지 TP 모드를 지원한다는 점입니다:

# src/nanotron/parallel/tensor_parallel/enum.py
class TensorParallelLinearMode(Enum):
    ALL_REDUCE = "all_reduce"
    REDUCE_SCATTER = "reduce_scatter"

ALL_REDUCE 모드 (lines 248-251):

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
    gathered_tensor = tensor  # 통신 없음, 입력이 이미 복제됨
    return F.linear(gathered_tensor, weight, bias)

REDUCE_SCATTER 모드 (lines 252-380):

elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
    # Forward: AllGather
    handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
    # Backward: ReduceScatter

언제 사용하는가:

  • ALL_REDUCE: 입력이 복제된 경우 (표준 TP)
  • REDUCE_SCATTER: 입력이 배치 차원으로 분할된 경우 (SP와 결합 시)

same_device_shard 최적화 패턴

AllGather 대기 중 자신의 데이터로 먼저 계산하여 통신을 오버랩합니다:

# src/nanotron/parallel/tensor_parallel/functional.py:305-327
# AllGather 비동기 시작
handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

# 출력 텐서를 before/same/after로 분할
before_shard, same_device_shard, after_shard = torch.split(
    gathered_output,
    split_size_or_sections=[
        sharded_batch_size * current_rank,
        sharded_batch_size,  # 자신의 데이터
        sharded_batch_size * (group_size - current_rank - 1),
    ],
    dim=0,
)

# AllGather 완료 전에 자신의 shard 먼저 계산
torch.mm(
    input=tensor.view(first_dims, hidden_size),
    mat2=weight.t(),
    out=same_device_shard.view(first_dims, output_size),
)

# AllGather 완료 대기
handle.wait()

# 나머지 shard 계산
if before_shard.numel() > 0:
    torch.mm(
        input=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
        mat2=weight.t(),
        out=before_shard.view(first_dims, output_size),
    )
# after_shard도 동일하게 처리

성능 이점: ~33% 계산이 통신과 오버랩 (TP=3일 때)

tp_recompute_allgather 트레이드오프

nanotron의 독특한 메모리 최적화:

# src/nanotron/parallel/tensor_parallel/nn.py
class TensorParallelColumnLinear(nn.Linear):
    def __init__(self, ..., tp_recompute_allgather=False):
        self.tp_recompute_allgather = tp_recompute_allgather

Forward (메모리 절약):

if tp_recompute_allgather:
    gathered_tensor = MemoryBuffer().get("allgather", ...)  # 버퍼 재사용
    ctx.save_for_backward(tensor, weight)  # sharded tensor만 저장

Backward (재계산):

if ctx.tp_recompute_allgather:
    unsharded_tensor = MemoryBuffer().get("allgather", ...)
    handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
    # AllGather 다시 수행

트레이드오프:

설정 메모리 계산
tp_recompute_allgather=False O(batch) AllGather 1회
tp_recompute_allgather=True O(batch/TP) AllGather 2회

bias 처리 차이

  • ColumnParallel: bias 분할됨 (out_features/TP)
  • RowParallel: rank 0만 bias 보유 (dist.get_rank(self.pg) == 0 and bias)

Contiguous Chunks

QKV 프로젝션처럼 여러 텐서를 하나로 퓨전할 때 유용:

qkv_contiguous_chunks = (
    config.num_attention_heads * self.d_qk,      # Q 청크
    config.num_key_value_heads * self.d_qk,      # K 청크
    config.num_key_value_heads * self.d_qk,      # V 청크
)

self.qkv_proj = TensorParallelColumnLinear(
    hidden_size, q_out + 2*kv_out,
    contiguous_chunks=qkv_contiguous_chunks,
)

각 청크 내에서 독립적으로 TP 분할이 적용됩니다.

ParallelContext: 5D 병렬화

nanotron은 5차원 병렬화를 지원합니다:

# src/nanotron/parallel/context.py
# [expert_parallel, pipeline_parallel, data_parallel, context_parallel, tensor_parallel]

self.tp_pg         # TP group
self.dp_pg         # DP group
self.pp_pg         # PP group
self.cp_pg         # Context Parallel group (시퀀스 병렬화)
self.ep_pg         # Expert Parallel group (MoE용)

2.3 DeepSpeed: 자동화 추구

핵심 파일: deepspeed/module_inject/

DeepSpeed는 자동 TP 적용에 초점을 맞춥니다. 사용자가 모델 코드를 수정하지 않아도 TP를 적용할 수 있습니다.

Quick Start

CLI:

# DeepSpeed launcher로 실행
deepspeed --num_gpus=8 train.py \
    --deepspeed \
    --deepspeed_config ds_config.json

JSON Config:

{
  "train_batch_size": 32,
  "tensor_parallel": {
    "enabled": true,
    "autotp_size": 8,
    "tp_grain_size": 64
  },
  "fp16": {
    "enabled": true
  }
}

Python API:

import deepspeed
from transformers import AutoModelForCausalLM

# 모델 로드
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# DeepSpeed 초기화 (AutoTP 자동 적용)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config="ds_config.json",
)

# 또는 수동으로 TP 적용
from deepspeed.module_inject import auto_tp
model = auto_tp.AutoTP(model, tp_size=8)

주의사항:

  • autotp_size: 자동 TP 적용 시 분할 크기
  • tp_grain_size: 불균등 vocab 분할 시 최소 단위
  • HuggingFace Transformers 모델에 자동 적용 가능

AutoTP GEM 리스트 탐지 로직

GEM(General Embedding/Matrix) 리스트는 all-reduce가 필요한 레이어를 자동 탐지합니다:

# deepspeed/module_inject/auto_tp.py:307-338
def tp_parser(model):
    for module in module_list:
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list = layer_list + ["." + key]
            elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
                layer_list = layer_list + ["ln"]

        for i, layer in enumerate(layer_list):
            if layer == 'ln':
                if layer_list[i - 1] != 'ln':
                    gem_list.append(layer_list[i - 1])  # LN 직전 = all-reduce 필요
            elif 'out_proj' in layer:
                gem_list.append(layer)
            elif 'down_proj' in layer:
                gem_list.append(layer)

GEM = General Embedding/Matrix: LayerNorm 직전 레이어, out_proj, down_proj 등 all-reduce가 필요한 레이어 목록

8종 Fused QKV 포맷 핸들링

다양한 모델이 QKV를 다르게 퓨전합니다:

# deepspeed/module_inject/fusedqkv_utils.py:34-46
fused_type_dict = {
    'CodeGenBlock': 'codegentype',   # [q(1),q(2),...,k(1),k(2),...,v(1),v(2),...]
    'BloomBlock': 'bloomtype',        # [q(1),k(1),v(1),q(2),k(2),v(2),...]
    'GLMBlock': 'glmtype',            # [Q,Q,...,K,K,...,V,V,...]
    "MPTBlock": 'glmtype',
    "BaichuanLayer": 'glmtype',
    "QWenBlock": 'qwentype',
    "FalconDecoderLayer": 'bloomtype',
    "GPTBigCodeBlock": 'bigcodetype',
    "Phi3DecoderLayer": "phi3type",   # Rotary 임베딩 위치 분리
}
모델 포맷 레이아웃
Bloom/Falcon bloomtype [q1,k1,v1,q2,k2,v2,...] (interleaved)
ChatGLM/MPT glmtype [Q,Q,...,K,K,...,V,V,...] (stacked)
CodeGen codegentype 멀티블록 레이아웃
Phi3 phi3type Rotary 임베딩 위치 분리

DeepSpeed는 각 포맷에 맞는 전치/분할 로직을 자동 적용합니다.

SubParamLinearLayer 불균등 파티셔닝

GQA(Grouped-Query Attention)처럼 Q와 KV 헤드 수가 다를 때:

# deepspeed/module_inject/layers.py
class SubParamLinearLayer(TensorParallel_Layer):
    def __init__(self, module, mp_group, shape, partition_dim=0):
        # shape = ((q_size, k_size, v_size), -1) for GQA
        # 각 서브파라미터를 독립적으로 분할

GQA 예시:

shape = ((4096, 1024, 1024), -1)  # Q: 4096, K: 1024, V: 1024
# 각 sub-param을 독립적으로 분할
sub_params = torch.split(tensor, subparam_sizes, dim=partition_dim)
partitioned = [torch.chunk(sp, tp_size, dim=0)[tp_idx] for sp in sub_params]

tp_grain_size 불균등 분할

어휘 크기가 TP 크기로 나누어 떨어지지 않을 때:

# deepspeed/module_inject/tp_shard.py:47-67
def get_shard_size(total_size, mp_size, name=None, rank=None):
    if total_size >= tp_grain_size:
        grain_size = total_size // tp_grain_size
        return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
    else:
        return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

예시: total=4096, tp_size=8, grain_size=128 → 각 랭크 512 토큰 (균등 분할)


2.4 torchtitan: PyTorch Native

핵심 파일: torchtitan/distributed/

torchtitan은 PyTorch의 DTensorDeviceMesh API를 사용하여 TP를 구현합니다. 커스텀 autograd 함수 없이 선언적으로 병렬화를 정의합니다.

Quick Start

CLI:

# TOML 설정 파일 기반 실행
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" \
    ./run_train.sh

TOML Config:

[parallelism]
tensor_parallel_degree = 8
enable_loss_parallel = true
pipeline_parallel_degree = 1
data_parallel_shard_degree = -1  # auto
data_parallel_replicate_degree = 1
context_parallel_degree = 1

[training]
enable_async_tensor_parallel = true

Python API:

from torchtitan.distributed import ParallelDims
from torch.distributed.device_mesh import init_device_mesh

# 병렬화 차원 정의
parallel_dims = ParallelDims(
    tp=8,
    pp=1,
    dp_shard=-1,  # auto
    dp_replicate=1,
    cp=1,
)

# DeviceMesh 초기화
world_mesh = parallel_dims.build_mesh(device_type="cuda")

# TP용 sub-mesh 추출
tp_mesh = world_mesh["tp"]

주의사항:

  • enable_loss_parallel=True: vocab-parallel cross-entropy로 메모리 절약
  • enable_async_tensor_parallel=True: torch.compile과 함께 비동기 TP 활성화
  • DTensor 기반이므로 torch.compile 완전 호환

선언적 병렬화 계획

# torchtitan/models/llama3/infra/parallelize.py:161-248
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, SequenceParallel,
    parallelize_module, PrepareModuleInput,
)

def apply_tp(model, tp_mesh, loss_parallel, enable_float8_tensorwise_tp, cp_enabled):
    # 1. 임베딩, 정규화, 출력 레이어 병렬화
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # 2. 각 Transformer 블록 병렬화
    for transformer_block in model.layers.values():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": PrepareModuleInput(
                input_layouts=(Shard(1), None, None, None),      # 현재: sequence dim 분할
                desired_input_layouts=(Replicate(), None, None, None),  # 목표: 복제
            ),
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": PrepareModuleInput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
            "feed_forward.w3": ColwiseParallel(),
        }
        parallelize_module(transformer_block, tp_mesh, layer_plan)

핵심 차이: 명시적 f/g 연산자 대신 DTensor가 레이아웃 변환을 자동 처리

loss_parallel 컨텍스트 매니저

# torchtitan/distributed/utils.py
if enable_loss_parallel:
    stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

효과:

  • 출력 레이어가 Shard(-1) (vocab dim 분할) 유지
  • all-gather 지연으로 메모리 절약
  • cross-entropy 계산 시 자동 분산 처리

Float8 TP 지원

# torchtitan/models/llama3/infra/parallelize.py:192-210
if enable_float8_tensorwise_tp:
    from torchao.float8.float8_tensor_parallel import (
        Float8ColwiseParallel,
        Float8RowwiseParallel,
        PrepareFloat8ModuleInput,
    )
    rowwise_parallel, colwise_parallel, prepare_module_input = (
        Float8RowwiseParallel,
        Float8ColwiseParallel,
        PrepareFloat8ModuleInput,
    )

제약: Tensorwise float8만 TP 지원, rowwise는 표준 TP 사용

레이아웃 명세

from torch.distributed.tensor import Replicate, Shard

Replicate()   # 전체 텐서를 모든 랭크에 복제
Shard(0)      # 0번 차원으로 분할 (batch)
Shard(1)      # 1번 차원으로 분할 (sequence/hidden)
Shard(-1)     # 마지막 차원으로 분할 (vocab dim)

torch.compile 호환

DTensor 기반 구현은 torch.compile과 자연스럽게 통합됩니다:

# Async TP 활성화
torch._inductor.config._micro_pipeline_tp = True

3. 추론 프레임워크

추론 프레임워크는 학습과 다른 최적화 방향을 가집니다:

  • 그래디언트 없음: backward pass 최적화 불필요
  • 레이턴시 중심: 단일 토큰 생성 시간이 중요
  • Weight Loading: 모델 로딩 시 샤딩 적용

3.1 vLLM: 추론 최적화의 정석

핵심 파일: vllm/distributed/, vllm/model_executor/layers/

Quick Start

CLI (권장):

# OpenAI-compatible 서버 시작
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --tensor-parallel-size 8 \
    --max-model-len 8192 \
    --gpu-memory-utilization 0.9

# 또는 단축 옵션 사용
vllm serve meta-llama/Llama-3.1-8B-Instruct -tp 8

Python API:

from vllm import LLM, SamplingParams

# 모델 로드 (TP 자동 적용)
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=8,
    max_model_len=8192,
    gpu_memory_utilization=0.9,
)

# 추론
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Hello, world!"], sampling_params)

환경 변수:

export VLLM_WORKER_MULTIPROC_METHOD=spawn  # 멀티프로세스 방식
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

주의사항:

  • CLI가 가장 권장되는 인터페이스
  • TP 크기는 GPU 수의 약수여야 함
  • --enforce-eager 플래그로 CUDA graph 비활성화 가능 (디버깅용)

GroupCoordinator 상세 구현

# vllm/distributed/parallel_state.py:276-505
class GroupCoordinator:
    """PyTorch ProcessGroup wrapper for a group of processes."""

    def __init__(self, group_ranks, local_rank, torch_distributed_backend, ...):
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)  # custom op에서 조회할 수 있도록 등록

        # CPU와 device 통신 그룹 분리
        self.cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        self.device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if self.use_custom_op_call:
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
        else:
            return self._all_reduce_out_place(input_)

핵심: 단일 GPU 시 통신 스킵, custom op으로 torch.compile 지원

Custom Op 등록

# vllm/distributed/parallel_state.py:248-273
direct_register_custom_op(
    op_name="all_reduce",
    op_func=all_reduce,
    fake_impl=all_reduce_fake,  # torch.compile용 shape 추론
)

direct_register_custom_op(
    op_name="reduce_scatter",
    op_func=reduce_scatter,
    fake_impl=reduce_scatter_fake,
)

direct_register_custom_op(
    op_name="all_gather",
    op_func=all_gather,
    fake_impl=all_gather_fake,
)

Weight Loader 패턴

추론에서는 모델 로딩 시 weight를 분할합니다 (런타임 아님):

# vllm/model_executor/layers/linear.py:551-586
class ColumnParallelLinear(LinearBase):
    def weight_loader(self, param, loaded_weight):
        output_dim = getattr(param, "output_dim", None)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)

        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            # 로드 시점에 weight를 분할
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        param_data.copy_(loaded_weight)

장점: 런타임에 분할 연산 없음, 추론 레이턴시 최소화

QKVParallelLinear GQA/MQA 처리

# vllm/model_executor/layers/linear.py:954-962
class QKVParallelLinear(ColumnParallelLinear):
    def __init__(self, hidden_size, head_size, total_num_heads, total_num_kv_heads, ...):
        tp_size = get_tensor_model_parallel_world_size()

        if tp_size >= self.total_num_kv_heads:
            # KV 헤드 < TP 크기: KV 헤드 복제
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            # KV 헤드 >= TP 크기: KV 헤드 분할
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1

GQA(Grouped-Query Attention)에서 KV 헤드 수가 Query 헤드보다 적을 때, 자동으로 복제/분할을 결정합니다.

input_is_parallel / reduce_results 플래그

# vllm/model_executor/layers/linear.py (RowParallelLinear)
def forward(self, input_):
    if self.input_is_parallel:
        input_parallel = input_  # scatter 스킵, 이미 분할되어 있음
    else:
        splitted = split_tensor_along_last_dim(input_, self.tp_size)
        input_parallel = splitted[self.tp_rank]

    output_parallel = self.quant_method.apply(self, input_parallel, bias)

    if self.reduce_results and self.tp_size > 1:
        output = tensor_model_parallel_all_reduce(output_parallel)

ColumnParallel의 출력을 바로 RowParallel에 연결할 때, 중간 scatter를 스킵합니다.


3.2 SGLang: 통신 오케스트레이션

핵심 파일: sglang/srt/layers/

SGLang은 vLLM 기반이지만, 구조화된 출력(structured output) 지원과 함께 독자적인 최적화를 추가했습니다.

Quick Start

CLI (권장):

# OpenAI-compatible 서버 시작
python -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B-Instruct \
    --tp-size 8 \
    --port 30000

# 또는 단축 명령
sglang serve meta-llama/Llama-3.1-8B-Instruct --tp 8

Python API:

import sglang as sgl

# 엔진 초기화
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    tp_size=8,
)

# 추론
outputs = engine.generate(
    prompt="Hello, world!",
    sampling_params={"temperature": 0.7, "max_new_tokens": 256},
)

고급 설정:

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    tp_size=8,
    dp_size=2,                      # DP+TP 이중 병렬화
    enable_flashinfer_allreduce_fusion=True,  # Hopper/Blackwell 최적화
)

주의사항:

  • --dp-size--tp-size 조합으로 DP+TP 이중 병렬화 가능
  • --enable-flashinfer-allreduce-fusion: SM90+ (Hopper/Blackwell) 전용 최적화
  • 구조화된 출력 시 --grammar 플래그 사용

ScatterMode 상세 설명

# sglang/srt/layers/communicator.py:102-120
class ScatterMode(Enum):
    """
    Suppose we have TP=4, DP=2, enable-dp-attention, and the system handles seq a,b,c,d
    Model input/output: [ab, ab, cd, cd] for four ranks respectively
    """
    SCATTERED = auto()      # [a, b, c, d] - 각 랭크가 자신만
    TP_ATTN_FULL = auto()   # [ab, ab, cd, cd] - TP attn 그룹 내 복제
    FULL = auto()           # [abcd, abcd, abcd, abcd] - 전체 복제

    @staticmethod
    def model_input_output():
        """The scatter mode for model forward pass input and output data"""
        if is_nsa_enable_prefill_cp():
            return ScatterMode.SCATTERED
        return ScatterMode.TP_ATTN_FULL

FlashInfer AllReduce Fusion

# sglang/srt/layers/communicator.py:89-99
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

def apply_flashinfer_allreduce_fusion(batch_size: int):
    return (
        (_is_sm90_supported or _is_sm100_supported)  # Hopper/Blackwell
        and _is_flashinfer_available
        and batch_size > 0
        and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
        and not is_dp_attention_enabled()
        and get_global_server_args().enable_flashinfer_allreduce_fusion
    )

LayerCommunicator

SGLang의 핵심 혁신:

# sglang/srt/layers/communicator.py:336-380
class LayerCommunicator:
    def __init__(
        self,
        layer_scatter_modes: LayerScatterModes,
        input_layernorm: torch.nn.Module,
        post_attention_layernorm: torch.nn.Module,
        allow_reduce_scatter: bool = False,
        is_last_layer: bool = False,
        qkv_latent_func: Optional[Callable] = None,
    ):
        self.layer_scatter_modes = layer_scatter_modes
        # 레이어별로 최적의 통신 패턴 결정
        self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
            input_mode=self.layer_scatter_modes.layer_input_mode,
            output_mode=self.layer_scatter_modes.attn_mode,
            context=self._context,
        )

LayerScatterModes 구조:

@dataclass
class LayerScatterModes:
    layer_input_mode: ScatterMode    # 레이어 입력
    attn_mode: ScatterMode           # Attention 계산
    mlp_mode: ScatterMode            # MLP 계산
    middle_residual_mode: ScatterMode # 중간 residual
    layer_output_mode: ScatterMode   # 레이어 출력

DP+TP 이중 병렬화

# sglang/srt/layers/dp_attention.py
def get_attention_tp_group() -> GroupCoordinator:
    return _ATTN_TP_GROUP  # Attention 전용 TP 그룹

def get_attention_dp_size() -> int:
    return _ATTN_DP_SIZE  # DP 크기

SGLang은 Attention에 대해 별도의 통신 그룹을 유지하여, DP와 TP를 효율적으로 결합합니다.

CommunicateContext

# sglang/srt/layers/communicator.py:608-641
@dataclass
class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
    attn_dp_size: int
    tp_size: int
    tp_rank: int

    @classmethod
    def init_new(cls):
        process_group_sizes = {
            ScatterMode.SCATTERED: 1,
            ScatterMode.TP_ATTN_FULL: attn_tp_size,
            ScatterMode.FULL: tp_size,
        }
        return cls(...)

4. 구현 패턴 비교

4.1 통신 패턴 비교 다이어그램

ColumnParallel (f 연산자):
Forward:  X ──[identity]──> X @ A_i ──> Y_i
Backward: dX <──[all-reduce]── dL/dY_i

RowParallel (g 연산자):
Forward:  Y_i ──[all-reduce]──> Y = Σ Y_i
Backward: dY <──[identity]── dL/dY_i

4.2 f/g 연산자 구현 비교

프레임워크 구현 방식 특징
Megatron-LM torch.autograd.Function 원조, 가장 정교한 비동기 오버랩
nanotron Differentiable* 클래스 모듈러, 2가지 TP 모드
DeepSpeed ColumnParallel, RowParallel 클래스 AutoTP와 통합
torchtitan DTensor 자동 처리 선언적, torch.compile 친화
vLLM GroupCoordinator 래퍼 추론 최적화, custom op
SGLang LayerCommunicator 레이어별 통신 오케스트레이션

4.3 프레임워크별 핵심 차이 표

측면 Megatron-LM nanotron DeepSpeed torchtitan vLLM SGLang
API 스타일 CLI + Python YAML 중심 JSON config TOML + DTensor CLI 최우선 CLI + Python
추상화 autograd.Function Differentiable* AutoTP 주입 DTensor GroupCoordinator LayerCommunicator
통신 오버랩 CUDA 스트림 same_device_shard AsyncColumnParallel torch.compile - FlashInfer 퓨전
GQA 지원 암시적 contiguous_chunks SubParam shape 파라미터 num_kv_head_replicas 동일
고유 최적화 wgrad 퓨전 커널 tp_recompute_allgather tp_grain_size loss_parallel custom op ScatterMode

4.4 MLP 병렬화

모든 프레임워크가 동일한 기본 패턴을 따릅니다:

┌─────────────────────────────────────────────────────────────┐
│  Input X [batch, seq, hidden]                               │
│      │                                                      │
│      ▼                                                      │
│  ColumnParallel (fc1/gate_up)                              │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  X @ A_i  (각 랭크가 A의 1/p 열을 담당)              │   │
│  │  통신 없음                                          │   │
│  └─────────────────────────────────────────────────────┘   │
│      │                                                      │
│      ▼ [batch, seq, 4*hidden/p]                            │
│  Activation (GeLU/SiLU)                                    │
│      │                                                      │
│      ▼                                                      │
│  RowParallel (fc2/down)                                    │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  Y_i @ B_i  (각 랭크가 B의 1/p 행을 담당)            │   │
│  │  All-Reduce: Z = Σ Y_i @ B_i                        │   │
│  └─────────────────────────────────────────────────────┘   │
│      │                                                      │
│      ▼ [batch, seq, hidden]                                │
│  Output Z (모든 랭크에 복제)                                │
└─────────────────────────────────────────────────────────────┘

차이점:

  • Megatron-LM: 비동기 all-reduce로 backward 오버랩
  • nanotron: REDUCE_SCATTER 모드로 Sequence Parallel과 결합 가능
  • DeepSpeed: GateUpPack_LinearLayer로 게이트+업 자동 퓨전
  • torchtitan: DTensor가 레이아웃 변환 자동 처리
  • vLLM/SGLang: gather_output=False로 중간 통신 스킵

4.5 Cross-Entropy 병렬화

학습 프레임워크(Megatron-LM, nanotron)는 vocab-parallel cross-entropy를 구현합니다:

def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
    # 각 랭크: [batch*seq, vocab/p] logits 보유

    # 1. 수치 안정성: max logit을 all-reduce (MAX)
    logits_max = vocab_parallel_logits.max(dim=-1)
    all_reduce(logits_max, op=MAX)

    # 2. 타겟 logit: 해당 랭크만 값을 가짐, all-reduce (SUM)
    # 3. 분할 함수: exp(logits).sum()을 all-reduce (SUM)

    # 통신량: O(batch * seq) - vocab 차원 제거됨

추론 프레임워크에서는 cross-entropy가 필요 없으므로 구현하지 않습니다.


5. 성능 고려사항

5.1 통신 복잡도

연산 Forward 통신 Backward 통신
ColumnParallel 없음 All-Reduce
RowParallel All-Reduce 없음
VocabParallel Embed All-Reduce 없음
VocabParallel CE 3× All-Reduce 없음

Transformer 레이어당: 4× All-Reduce (학습 시 forward + backward)

5.2 통신-계산 오버랩

비동기 통신은 다음 조건에서 효과적입니다:

  • CUDA_DEVICE_MAX_CONNECTIONS=1: 커널 스케줄링 순서 보장
  • 큰 hidden dimension: 계산 시간 > 통신 시간
  • Gradient Accumulation Fusion: CUDA 커널로 그래디언트 누적

5.3 메모리 사용량

프레임워크 메모리 절약 기법
Megatron-LM Activation Checkpointing, Fused Kernels
nanotron tp_recompute_allgather 플래그
DeepSpeed ZeRO-3 + TP 결합
torchtitan FSDP2 + TP 결합
vLLM KV Cache 최적화
SGLang Input Scattered Mode

6. 어떤 프레임워크를 선택할 것인가?

6.1 학습 목적

시나리오 추천 프레임워크 이유
대규모 사전학습 Megatron-LM 가장 성숙한 3D 병렬화, 최적화된 커널
연구/실험 nanotron 모듈러 설계, 빠른 프로토타이핑
기존 HF 모델 학습 DeepSpeed AutoTP, ZeRO 통합
PyTorch 생태계 통합 torchtitan DTensor 네이티브, torch.compile

6.2 추론 목적

시나리오 추천 프레임워크 이유
일반 LLM 서빙 vLLM PagedAttention, 성숙한 생태계
구조화된 출력 SGLang Grammar 지원, LayerCommunicator
커스텀 최적화 둘 다 가능 코드베이스 이해 후 확장

7. 결론

Tensor Parallelism의 이론은 2019년 Megatron-LM 논문에서 정립되었지만, 구현은 각 프레임워크의 목적에 따라 크게 달라집니다.

핵심 교훈:

  1. f/g 켤레 연산자는 모든 구현의 기초이지만, DTensor처럼 추상화할 수도 있습니다.
  2. 학습 vs 추론은 최적화 방향이 다릅니다: 그래디언트 통신 vs weight loading
  3. 통신-계산 오버랩은 학습에서 핵심이지만, 추론에서는 덜 중요합니다.
  4. 자동화 수준은 생산성(DeepSpeed AutoTP)과 제어력(Megatron-LM) 사이의 트레이드오프입니다.

실제 시스템을 구축할 때는 이론뿐 아니라 각 프레임워크의 구현 세부사항을 이해하는 것이 중요합니다. 이 글이 그 이해에 도움이 되길 바랍니다.


부록: 핵심 파일 참조

프레임워크 핵심 파일 주요 내용
Megatron-LM megatron/core/tensor_parallel/layers.py ColumnParallel, RowParallel
Megatron-LM megatron/core/tensor_parallel/mappings.py f/g 연산자
nanotron src/nanotron/parallel/tensor_parallel/nn.py TP 레이어 정의
nanotron src/nanotron/parallel/tensor_parallel/functional.py 비동기 통신 구현
DeepSpeed deepspeed/module_inject/layers.py Linear 레이어 래퍼
DeepSpeed deepspeed/module_inject/auto_tp.py AutoTP 파서
DeepSpeed deepspeed/module_inject/fusedqkv_utils.py Fused QKV 포맷 핸들링
DeepSpeed deepspeed/module_inject/tp_shard.py 불균등 분할 로직
torchtitan torchtitan/distributed/tensor_parallel.py DTensor 기반 TP
torchtitan torchtitan/models/llama3/infra/parallelize.py 병렬화 계획
vLLM vllm/distributed/parallel_state.py GroupCoordinator
vLLM vllm/model_executor/layers/linear.py TP Linear 레이어
SGLang sglang/srt/layers/linear.py TP Linear 레이어
SGLang sglang/srt/layers/communicator.py LayerCommunicator



    Enjoy Reading This Article?

    Here are some more articles you might like to read next:

  • Flash Attention 3
  • Flash Attention 2
  • Flash Attention
  • Unified Sequence Parallelism
  • Reducing Activation Recomputation in Large Transformer Models
  • Stay updated — subscribe via RSS




    Leave a Comment

    Found this useful or have questions? Sign in with GitHub to join the conversation.