OOM error solution - Llama2 70b
Env
- OS: ubuntu 22.04
- python: 3.10.12
- transformers: 4.34.0
- peft: 0.5.0
- bitsandbytes: 0.41.1
- accelerate: 0.23.0
- torch: 2.1.0+cu118
- nvidia-driver-version: 525.147.05
- GPU: A100 80GB x 8
Error
- When fine-tuning with `transformers.Trainer`, training runs for only a few steps and then fails with OOM.
- Apparent memory leak: memory usage never stabilizes, and the GPU allocated memory differs at every step.
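One way to tell steady growth apart from plain fluctuation is to record allocated memory once per step and summarize the trace. The sketch below is only an illustration, not part of the original script: `summarize_memory` is a hypothetical helper, and the readings are placeholder numbers standing in for values that could be collected with `torch.cuda.memory_allocated()` after each step.

```python
from statistics import mean


def summarize_memory(readings_mib):
    """Summarize per-step GPU memory readings (in MiB).

    A clearly positive average step-to-step change suggests growth;
    a value near zero with a wide min/max range suggests fluctuation
    (for example, from variable sequence lengths).
    """
    if len(readings_mib) < 2:
        raise ValueError("need at least two readings")
    deltas = [b - a for a, b in zip(readings_mib, readings_mib[1:])]
    return {
        "min": min(readings_mib),
        "max": max(readings_mib),
        "avg_delta": mean(deltas),
    }


# Placeholder readings, NOT measurements from the original run.
# In a real run, collect one value per step, for example with
# torch.cuda.memory_allocated(device) / 2**20.
example = [15000, 15400, 15100, 15900, 15300]
print(summarize_memory(example))
```

If the average change stays positive over many steps, a leak is more plausible; if it hovers around zero, the per-step differences may simply reflect batch contents.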
Solution
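- After reviewing transformers issue #20287, suspected a casting problem and switched the half-precision cast from bf16 to fp16.
- Reduced the `logging_steps` and `warmup_steps` hyperparameters to small values.
- Tried wrapping the model with `prepare_model_for_kbit_training` before calling `get_peft_model`.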
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:27:00.0 Off |                    0 |
| N/A   31C    P0    68W / 400W |  15251MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  Off  | 00000000:2A:00.0 Off |                    0 |
| N/A   27C    P0    68W / 400W |  15961MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  Off  | 00000000:51:00.0 Off |                    0 |
| N/A   28C    P0    67W / 400W |  15963MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  Off  | 00000000:57:00.0 Off |                    0 |
| N/A   32C    P0    67W / 400W |  15969MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  Off  | 00000000:88:00.0 Off |                    0 |
| N/A   41C    P0   314W / 400W |  15971MiB / 81920MiB |     64%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM...  Off  | 00000000:8E:00.0 Off |                    0 |
| N/A   31C    P0    79W / 400W |  15889MiB / 81920MiB |     28%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM...  Off  | 00000000:A5:00.0 Off |                    0 |
| N/A   29C    P0    77W / 400W |  15779MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM...  Off  | 00000000:A8:00.0 Off |                    0 |
| N/A   33C    P0    83W / 400W |  18915MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   2028046      C   ...e/user/lora/bin/python    15248MiB |
|    1   N/A  N/A   2028046      C   ...e/user/lora/bin/python    15958MiB |
|    2   N/A  N/A   2028046      C   ...e/user/lora/bin/python    15960MiB |
|    3   N/A  N/A   2028046      C   ...e/user/lora/bin/python    15966MiB |
|    4   N/A  N/A   2028046      C   ...e/user/lora/bin/python    15968MiB |
|    5   N/A  N/A   2028046      C   ...e/user/lora/bin/python    15886MiB |
|    6   N/A  N/A   2028046      C   ...e/user/lora/bin/python    15776MiB |
|    7   N/A  N/A   2028046      C   ...e/user/lora/bin/python    18912MiB |
+-----------------------------------------------------------------------------+
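To put the snapshot in perspective, the per-GPU figures can be totaled with a few lines of Python. This is a minimal sketch: `parse_memory_usage` is a hypothetical helper, and the input is just the memory-usage cells copied from the output above. It describes one moment in time, not the peak reached before the OOM.

```python
import re

# Memory-usage cells copied from the nvidia-smi snapshot (used / total per GPU).
SNAPSHOT = """
15251MiB / 81920MiB
15961MiB / 81920MiB
15963MiB / 81920MiB
15969MiB / 81920MiB
15971MiB / 81920MiB
15889MiB / 81920MiB
15779MiB / 81920MiB
18915MiB / 81920MiB
"""


def parse_memory_usage(text):
    """Return a list of (used_mib, total_mib) pairs found in nvidia-smi text."""
    pattern = re.compile(r"(\d+)MiB\s*/\s*(\d+)MiB")
    return [(int(used), int(total)) for used, total in pattern.findall(text)]


usage = parse_memory_usage(SNAPSHOT)
total_used = sum(used for used, _ in usage)
total_capacity = sum(total for _, total in usage)
print(f"GPUs: {len(usage)}")
print(f"used: {total_used} MiB of {total_capacity} MiB "
      f"({100 * total_used / total_capacity:.1f}%)")
```

At this moment roughly a fifth of the total capacity is in use, so the snapshot alone does not show the memory pressure that leads to the OOM.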
Source code
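import sys
from dataclasses import dataclass, field
from typing import Optional

import torch
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

@dataclass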
class ScriptArgument:
    ## model argument
    model_name: Optional[str] = field(default = "beomi/llama-2-ko-70b", metadata = {"help" : "the model name"})
    
    ## train arguments
    batch_size: Optional[int] = field(default = 2, metadata = {"help" : "the batch size"})
    gradient_accumulation_steps: Optional[int] = field(default = 2, metadata = {"help" : "the gradient accumulation steps"})
    num_train_epochs: Optional[int] = field(default = 1, metadata = {"help" : "the number of training epochs"})
    learning_rate: Optional[float] = field(default = 1.41e-5, metadata = {"help" : "the learning rate"})
    logging_steps: Optional[int] = field(default = 100, metadata = {"help" : "logging_steps"})
    save_steps: Optional[int] = field(default = 100, metadata = {"help" : "save_steps"})
    save_total_limit: Optional[int] = field(default = 3, metadata = {"help" : "the limitation of total save file"})
    output_dir: Optional[str] = field(default = "results_70b", metadata = {"help" : "output directory"})
    
    cutoff_len: Optional[int] = field(default = 1024, metadata = {"help" : "maximum padding length"})
    
    ## peft & lora arguments
    # use_peft: Optional[bool] = field(default = True, metadata = {"help" : "Using peft library"})
    load_in_8bit: Optional[bool] = field(default = True, metadata = {"help" : "Using 8bits"})
    
    lora_r: Optional[int] = field(default = 8, metadata = {"help" : "lora r"})
    lora_alpha: Optional[int] = field(default = 16, metadata = {"help" : "lora alpha"})
    lora_dropout: Optional[float] = field(default = 0.05, metadata = {"help" : "lora dropout (not statistically related to performance, according to the original paper)"})
    
script_args = ScriptArgument()
if script_args.load_in_8bit:
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name,
        torch_dtype = torch.float16,
        device_map = 'auto',
        load_in_8bit = True,
        trust_remote_code = True
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name,
        torch_dtype = torch.float16,
        device_map = 'auto',
        trust_remote_code = True
    )
    
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
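# Pad with id 0 (<unk>) rather than eos: the language-modeling collator masks pad tokens
# out of the labels, so padding with eos would also hide real eos tokens from the loss.
tokenizer.pad_token_id = 0      # == unk_token_id (<unk>)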
config = LoraConfig(
    r = script_args.lora_r,
    lora_alpha = script_args.lora_alpha,
    lora_dropout = script_args.lora_dropout,
    target_modules = ["q_proj", "v_proj"],  ## find_all_linear_names(model)
    bias = "none",
    task_type = TaskType.CAUSAL_LM
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, config)
model.is_parallelizable = True
model.model_parallel = True
trainer = Trainer(
    model = model,
    train_dataset = data['train'],
    # eval_dataset = val_data,
    args = TrainingArguments(
        per_device_train_batch_size = script_args.batch_size,
        gradient_accumulation_steps = script_args.gradient_accumulation_steps,
        learning_rate = script_args.learning_rate,
        num_train_epochs = script_args.num_train_epochs,
        optim = 'adamw_torch',
        fp16 = True,
        warmup_steps = 10,
        logging_steps = 10,
        save_steps = 10,
        save_total_limit = script_args.save_total_limit,
        save_strategy = "steps",
        output_dir = script_args.output_dir,
        remove_unused_columns = False
    ),
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
)
model.config.use_cache = False
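# Compare the major version numerically; a string comparison would misorder e.g. "10.0" and "2".
if int(torch.__version__.split(".")[0]) >= 2 and sys.platform != "win32":
    # Note: the Trainer above already holds the uncompiled model, so this rebinding
    # does not change what trainer.train() runs.
    model = torch.compile(model)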
trainer.train()
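As a sanity check on how little of the model this LoRA configuration trains, the adapter size can be worked out by hand. The sketch below rests on assumptions that are not in the original: the dimensions are those of the public Llama-2-70B architecture (hidden size 8192, 80 layers, 8 key/value heads of dimension 128), and `beomi/llama-2-ko-70b` is assumed to share them, which should be verified against the checkpoint's `config.json`. `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(in_features, out_features, r):
    """LoRA adds A (r x in_features) and B (out_features x r) per target layer."""
    return r * (in_features + out_features)


# ASSUMED dimensions (public Llama-2-70B config); check config.json for your checkpoint.
HIDDEN_SIZE = 8192
NUM_LAYERS = 80
NUM_KV_HEADS = 8
HEAD_DIM = 128

R = 8  # lora_r default in ScriptArgument

# q_proj maps hidden -> hidden; v_proj maps hidden -> (kv heads * head dim).
q_proj = lora_param_count(HIDDEN_SIZE, HIDDEN_SIZE, R)
v_proj = lora_param_count(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, R)
trainable = NUM_LAYERS * (q_proj + v_proj)
print(f"trainable LoRA parameters: {trainable:,}")
```

Under these assumptions the adapters and their optimizer state are tiny next to the base weights, which suggests that the memory pressure comes mainly from the 8-bit base model and the activations rather than from LoRA itself.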
| Step | Training-Loss | 
|---|---|
| 10 | 2.706500 | 
| 20 | 2.482400 | 
| 30 | 2.715400 |
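To read the table, it helps to know how many samples each logged step represents. The sketch below assumes the Trainer treats this run as a single model-parallel replica (the model is spread over the GPUs with `device_map = 'auto'` and `model_parallel = True`), so there is no data-parallel multiplier. `samples_seen` is a hypothetical helper.

```python
def samples_seen(step, per_device_batch_size, grad_accum_steps, num_replicas=1):
    """Number of training samples consumed after `step` optimizer steps."""
    return step * per_device_batch_size * grad_accum_steps * num_replicas


BATCH_SIZE = 2  # batch_size default in ScriptArgument
GRAD_ACCUM = 2  # gradient_accumulation_steps default in ScriptArgument

for step in (10, 20, 30):
    print(step, samples_seen(step, BATCH_SIZE, GRAD_ACCUM))
```

With these defaults each optimizer step consumes 4 samples, so the three logged losses cover only 120 samples in total, which is too few to read a trend from.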