0. Series Overview

| This Article | Upstream | Output | Downstream |
|---|---|---|---|
| Article 6/10 | Article 05 – Dataset Ready | final_lora/, 15 checkpoints | Article 07 – Reading Logs · Article 08 – Verification |

This is the only code snippet that modifies LoRA weights. After running, you should see:

✅ Single GPU fine-tuning complete! LoRA weights: ./output/lora_elderly_single/final_lora

(Matching the last line of all_logs.log, with slight wording differences in the path.)


1. The Actual Problem Solved

Steps 5–7 handle:

  1. Merging HuggingFace TrainingArguments with TRL‑specific fields (dataset_text_field, max_length) into SFTConfig
  2. Creating SFTTrainer, where LoRA is injected at that point
  3. Running train() for 750 steps, and save_pretrained to save only the adapter

Many people get stuck on TRL 1.x API changes: processing_class replaces tokenizer, max_length replaces max_seq_length.
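
As a minimal sketch of those two renames, assuming you are porting keyword arguments from a TRL 0.x tutorial, the hypothetical helper below (migrate_trl_kwargs is not part of TRL) rewrites the old names in a plain dict. Which object each key finally goes to (SFTConfig or SFTTrainer) is still up to the caller.

```python
# Hypothetical migration helper; the mapping covers only the two renames named above.
LEGACY_TO_CURRENT = {
    "tokenizer": "processing_class",   # SFTTrainer argument
    "max_seq_length": "max_length",    # SFTConfig argument
}


def migrate_trl_kwargs(kwargs):
    """Return a copy of kwargs with legacy TRL 0.x names renamed."""
    migrated = {}
    for name, value in kwargs.items():
        new_name = LEGACY_TO_CURRENT.get(name, name)
        if new_name in migrated:
            # Both the old and the new spelling were supplied; refuse to guess.
            raise ValueError(f"both legacy and current name given for {new_name!r}")
        migrated[new_name] = value
    return migrated


old_style = {"max_seq_length": 512, "dataset_text_field": "text"}
new_style = migrate_trl_kwargs(old_style)
print(new_style)
```

Raising on a duplicate, rather than letting one value win, keeps a half-migrated script from training with a length you did not intend.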


2. Implementation Locations

| Code Block | Lines |
|---|---|
| SFTConfig(...) | 204–221 |
| SFTTrainer(...) | 226–233 |
| trainer.train() | 240 |
| save_pretrained | 244–246 |
| TrainingProgressCallback | 82–99 |

Output directory structure:

output/lora_elderly_single/
├── checkpoint-50/ ... checkpoint-750/ # Each contains optimizer, ~81MB
└── final_lora/ # Only adapter + tokenizer, ~41MB
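
For a rough idea of how much disk this layout needs, the sketch below multiplies the approximate sizes from the listing above. It assumes every checkpoint is kept (no save_total_limit, which the config in this article does not set); the helper name estimate_output_mb is made up for illustration.

```python
CHECKPOINT_MB = 81   # approximate size per checkpoint, from the listing above
ADAPTER_MB = 41      # approximate size of final_lora/


def estimate_output_mb(total_steps, save_steps):
    """Approximate disk footprint of the output directory in MB."""
    num_checkpoints = total_steps // save_steps
    return num_checkpoints * CHECKPOINT_MB + ADAPTER_MB


print(estimate_output_mb(750, 50))  # 15 checkpoints plus the adapter
print(estimate_output_mb(750, 10))  # 75 checkpoints plus the adapter
```

The second call shows why a very small save_steps gets expensive quickly: the checkpoint count grows five-fold while the adapter size stays fixed.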

3. SFTConfig Parameter Breakdown

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    bf16=True,
    logging_steps=1,
    logging_first_step=True,
    save_steps=50,
    optim="paged_adamw_8bit",
    report_to="none",
    remove_unused_columns=True,
    disable_tqdm=False,
    logging_strategy="steps",
    dataset_text_field="text",
    max_length=MAX_SEQ_LEN,
)

| Parameter | Role | Explanation |
|---|---|---|
| per_device_train_batch_size=2 | Micro-batch | 2 samples per forward pass |
| gradient_accumulation_steps=2 | Accumulation | 2×2=4 samples before optimizer.step |
| learning_rate=2e-4 | Common for LoRA | Linear decay to ~0 |
| bf16=True | Mixed precision | Matches the loading dtype |
| logging_steps=1 | Log every step | Works with the callback print |
| save_steps=50 | Checkpoints | 750/50=15 checkpoints |
| optim="paged_adamw_8bit" | 8‑bit Adam | Saves memory for optimizer states |
| dataset_text_field="text" | Field name | Corresponds to load_jsonl_data |
| max_length=512 | Truncation | Truncates long samples at the end |
| report_to="none" | No uploading | No W&B |
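
The "linear decay to ~0" noted for learning_rate can be sketched in a few lines. This assumes the Transformers defaults apply, since the config above sets neither a scheduler type nor warmup: a linear schedule with zero warmup steps. The function is a standalone illustration of that shape, not the library's own code.

```python
def linear_lr(step, total_steps, base_lr, warmup_steps=0):
    """Linear warmup (optional) followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Values for this run: 750 optimizer steps, peak learning rate 2e-4.
for step in (0, 375, 750):
    print(step, linear_lr(step, total_steps=750, base_lr=2e-4))
```

Halfway through the run the rate is half the peak, and it reaches zero at the final step, which is why the last few logged lr values look tiny.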

No eval_strategy set: no validation set, so logs will not contain eval_loss (see Article 07).
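
The 750 steps and 15 checkpoints follow from simple batch arithmetic. The sketch below assumes 1000 training samples and 3 epochs; neither number is stated in this article, both are inferred from the callback line in section 4.3 where step 250 coincides with epoch 1.00. The Trainer's own rounding can differ slightly when the dataset does not divide evenly, so treat this as an estimate.

```python
import math


def training_plan(num_samples, batch_size, grad_accum, epochs, save_steps):
    """Estimate optimizer steps and checkpoint count for a single-GPU run."""
    effective_batch = batch_size * grad_accum
    micro_batches = math.ceil(num_samples / batch_size)
    steps_per_epoch = max(1, math.ceil(micro_batches / grad_accum))
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "checkpoints": total_steps // save_steps,
    }


# num_samples=1000 and epochs=3 are inferred values, not taken from the script.
plan = training_plan(num_samples=1000, batch_size=2, grad_accum=2, epochs=3, save_steps=50)
print(plan)
```

Swapping to batch_size=1 with grad_accum=4 gives the same effective batch and therefore the same step count.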


4. SFTTrainer Construction and train()

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=lora_config,
    processing_class=tokenizer,
    args=training_args,
    callbacks=[TrainingProgressCallback()],
)

trainer.model.print_trainable_parameters()
trainer.train()

4.1 Injection Timing

When it is constructed with peft_config, SFTTrainer calls PEFT internally and attaches LoRA layers to the target_modules. The print_trainable_parameters() call that follows then prints:

trainable params: 10,616,832 || all params: 4,216,368,128 || trainable%: 0.2518
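
The percentage in that line is just the ratio of the two counts, and the trainable count itself is the sum of small per-layer terms. The sketch below reproduces the ratio from the printed numbers and shows the per-layer formula for one adapted linear layer; the 4096 and r=8 in the second call are illustrative values only, not this project's dimensions or rank.

```python
def trainable_percent(trainable, total):
    """Share of parameters that receive gradients, in percent."""
    return 100.0 * trainable / total


def lora_params_per_layer(d_in, d_out, r):
    """Parameters LoRA adds to one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


pct = trainable_percent(10_616_832, 4_216_368_128)
print(f"trainable%: {pct:.4f}")

# Illustrative layer shape and rank, not taken from the project.
print(lora_params_per_layer(4096, 4096, 8))
```

If the printed trainable% is far from what this arithmetic predicts for your rank and target modules, that is a hint the adapter was attached somewhere unexpected.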

4.2 What Happens Inside the Training Loop (Conceptual)

sequenceDiagram
    participant T as SFTTrainer
    participant M as Base+LoRA
    participant O as paged_adamw_8bit

    loop Each step (750 total)
        T->>M: forward(text batch)
        M-->>T: loss (causal LM)
        T->>M: backward (only LoRA has grad)
        Note over T: Accumulates every 2 micro‑steps
        T->>O: step updates LoRA weights
    end

The base weights (W) receive no gradients; only the LoRA matrices (A, B) are updated.
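
A toy forward pass makes the roles of W, A and B concrete. This is a pure-Python sketch using the usual LoRA formulation (output = W·x + (alpha/r)·B·A·x, with B starting at zero); the tiny matrices and the alpha value are illustrative only and have nothing to do with the project's settings.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Frozen base projection plus the scaled low-rank update B @ A @ x."""
    base = matvec(W, x)                 # W is never modified
    delta = matvec(B, matvec(A, x))     # only A and B are trained
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


W = [[1, 0], [0, 1]]   # frozen 2x2 base weight
A = [[1, 1]]           # rank-1 down-projection (r x d_in)
x = [3, 4]

# With B at its zero initialization the adapter contributes nothing.
y_init = lora_forward(W, A, [[0], [0]], x, alpha=2, r=1)
# After training, a non-zero B shifts the output while W stays untouched.
y_trained = lora_forward(W, A, [[1], [2]], x, alpha=2, r=1)
print(y_init, y_trained)
```

Because W is only read, never written, saving A and B alone is enough to reproduce the fine-tuned behavior on top of the original base model.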

4.3 Dual Logging

In the same step you will see:

  1. tqdm progress bar (3.31s/it)
  2. Callback line: [Progress 33.3%] Step 250/750 | Epoch 1.00 | loss=0.2402
  3. Transformers JSON log: {'loss': '0.24', 'mean_token_accuracy': '0.957', ...}

When reviewing, rely on the Callback line + summary line.
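
If you want to pull the callback lines out of all_logs.log programmatically, a small regular expression is enough. The sketch below targets the line format shown above; the function name parse_progress_line is made up for illustration, and the pattern tolerates extra spaces because the callback pads the percentage to a fixed width.

```python
import re

PROGRESS_RE = re.compile(
    r"\[Progress\s+(?P<pct>[\d.]+)%\]\s+Step\s+(?P<step>\d+)/(?P<total>\d+)"
    r"\s+\|\s+Epoch\s+(?P<epoch>[\d.]+)\s+\|\s+loss=(?P<loss>[\d.]+)"
)


def parse_progress_line(line):
    """Extract step, total, epoch and loss from one callback line, or None."""
    match = PROGRESS_RE.search(line)
    if match is None:
        return None
    return {
        "step": int(match["step"]),
        "total": int(match["total"]),
        "epoch": float(match["epoch"]),
        "loss": float(match["loss"]),
    }


sample = "[Progress 33.3%] Step 250/750 | Epoch 1.00 | loss=0.2402"
record = parse_progress_line(sample)
print(record)
```

Lines that do not match, such as the tqdm bar or the dict-style Transformers log, simply return None and can be skipped.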


5. Step 7: Saving

save_path = f"{OUTPUT_DIR}/final_lora"
trainer.model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

This saves the PeftModel’s adapter, not the merged full weights. For inference:

base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, ...)
model = PeftModel.from_pretrained(base, LORA_PATH)

For vLLM, use --lora-modules elderly=./output/.../final_lora (Article 10).


6. Checkpoints and Resumption

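The checkpoint-750/ directory contains:

  • trainer_state.json with global_step: 750
  • log_history (also in trainer_state.json): loss for each logged step
  • optimizer / scheduler states in separate files, optimizer.pt and scheduler.pt (for --resume_from_checkpoint)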

For everyday verification use final_lora; intermediate checkpoints are only needed to resume interrupted training.
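
To inspect a checkpoint without loading any model, trainer_state.json can be read with the standard library. The sketch below assumes the usual layout where each log_history entry is a dict carrying "step" and, for training logs, "loss"; the helper name load_loss_history is hypothetical.

```python
import json
from pathlib import Path


def load_loss_history(checkpoint_dir):
    """Return (global_step, [(step, loss), ...]) from a checkpoint directory."""
    state_path = Path(checkpoint_dir) / "trainer_state.json"
    state = json.loads(state_path.read_text())
    losses = [
        (entry["step"], entry["loss"])
        for entry in state["log_history"]
        if "loss" in entry   # skip summary entries that have no per-step loss
    ]
    return state["global_step"], losses
```

Calling it on a checkpoint directory such as checkpoint-750 gives the step counter plus a list you can plot or diff against another run.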


7. OOM Contingency (in Priority Order)

BATCH_SIZE = 1                   # Roughly halves activation memory
GRADIENT_ACCUMULATION_STEPS = 4  # Keeps effective batch = 4
# or
MAX_SEQ_LEN = 256
# or
LORA_R = 4
# Last resort; note this is not a memory saver
optim = "adamw_torch"            # Drops the 8‑bit optimizer, so memory usage actually increases

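The step count (750) changes whenever the effective batch size or the number of epochs changes, so do not compare step numbers directly with old logs. The BATCH_SIZE=1 / GRADIENT_ACCUMULATION_STEPS=4 option keeps the effective batch at 4 and therefore leaves the step count unchanged.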


8. Pitfalls

Pitfall 1: max_seq_length from TRL 0.x tutorials used in SFTConfig
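TRL 1.x uses max_length. SFTConfig no longer accepts the old name, so passing max_seq_length typically fails with an unexpected-keyword error. The silent variant is deleting the argument to get past that error: max_length then falls back to its default, and you might think you’re training with 512 but actually use 1024.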

Pitfall 2: Calling get_peft_model twice
See Article 05 – leads to abnormal trainable%.

Pitfall 3: save_steps set too low
Saving every 10 steps fills up disk space and slows down I/O. 50 steps is reasonable for this project.

Pitfall 4: After training, only copying checkpoint-750, forgetting final_lora
The final_lora directory written by save_pretrained after trainer.train() finishes is the clean adapter; checkpoints also carry optimizer states and are larger.
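
Before copying artifacts off the training machine, a quick check along these lines can confirm that final_lora is really there and looks like an adapter. It assumes PEFT's usual file names (adapter_config.json plus adapter_model.safetensors, or adapter_model.bin from older saves); the helper is_adapter_dir is hypothetical.

```python
from pathlib import Path

# File names PEFT normally writes for an adapter (assumed here, not read from the script).
ADAPTER_WEIGHT_NAMES = ("adapter_model.safetensors", "adapter_model.bin")


def is_adapter_dir(path):
    """True if path holds an adapter config and at least one adapter weight file."""
    path = Path(path)
    has_config = (path / "adapter_config.json").is_file()
    has_weights = any((path / name).is_file() for name in ADAPTER_WEIGHT_NAMES)
    return has_config and has_weights
```

Running it against ./output/lora_elderly_single/final_lora before archiving is a cheap way to catch a run that stopped before the save step.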


9. Summary

  1. SFTConfig manages hyperparameters, max_length, and dataset_text_field.
  2. SFTTrainer + peft_config is the only point where LoRA is injected.
  3. 750 steps / 41 min / train_loss 0.2587 measured on a V100.
  4. final_lora is the path for deployment and verification; checkpoints are for resuming and loss curves.
  5. No validation set – effectiveness is verified in Article 08 (inference).

Appendix: TrainingProgressCallback.on_log

# LoRA_Demo/train_lora_single.py lines 88‑96

def on_log(self, args, state, control, logs=None, **kwargs):
    if not logs or "loss" not in logs:
        return
    pct = (state.global_step / state.max_steps * 100) if state.max_steps else 0
    print(
        f"[Progress {pct:5.1f}%] Step {state.global_step}/{state.max_steps} | "
        f"Epoch {logs.get('epoch', state.epoch):.2f} | "
        f"loss={logs['loss']:.4f} | lr={logs.get('learning_rate', 'N/A')}"
    )
# With logging_steps=1, this generates ~750 lines of progress in all_logs.log
