# peft/tuners/lora.py - line 601
def merge_and_unload(self):
    r"""
    This method merges the LoRA layers into the base model. This is needed if someone wants to use the base
    model as a standalone model.

    Example:

    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel

    >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
    >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
    >>> model = PeftModel.from_pretrained(base_model, peft_model_id)
    >>> merged_model = model.merge_and_unload()
    """
    return self._unload_and_optionally_merge()
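# Usage sketch: the returned model is a plain transformers model, so it can be
# saved and reloaded without peft. The output directory "falcon-40b-merged"
# below is a hypothetical path, not something defined by the library.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
model = PeftModel.from_pretrained(base_model, "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample")

merged_model = model.merge_and_unload()
merged_model.save_pretrained("falcon-40b-merged")

# The merged checkpoint can then be loaded directly with transformers:
reloaded = AutoModelForCausalLM.from_pretrained("falcon-40b-merged")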
# peft/tuners/lora.py - line 438
def _unload_and_optionally_merge(self, merge=True):
    # Merging needs full-precision weights, so quantized (bitsandbytes) models are rejected up front.
    if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
        raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit or 4-bit mode")
    # Collect every module key that is not itself a LoRA submodule (lora_A, lora_B, ...).
    key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
    for key in key_list:
        try:
            parent, target, target_name = _get_submodules(self.model, key)
        except AttributeError:
            continue
        if isinstance(target, LoraLayer):
            # Rebuild a vanilla torch module with the same shape to stand in for the LoRA-wrapped one.
            if isinstance(target, nn.Embedding):
                new_module = torch.nn.Embedding(target.in_features, target.out_features)
            elif isinstance(target, nn.Conv2d):
                new_module = torch.nn.Conv2d(
                    target.in_channels,
                    target.out_channels,
                    kernel_size=target.kernel_size,
                    stride=target.stride,
                    padding=target.padding,
                    dilation=target.dilation,
                )
            else:
                bias = target.bias is not None
                if getattr(target, "is_target_conv_1d_layer", False):
                    new_module = Conv1D(target.out_features, target.in_features)
                else:
                    new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
            if merge:
                # Fold the LoRA update into the base weight before swapping modules.
                target.merge()
            self._replace_module(parent, target_name, new_module, target)

        # save any additional trainable modules part of `modules_to_save`
        if isinstance(target, ModulesToSaveWrapper):
            setattr(parent, target_name, target.modules_to_save[target.active_adapter])

    return self.model
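# A minimal sketch of the arithmetic behind target.merge() for a Linear LoRA
# layer: the low-rank update is folded into the frozen base weight,
# W_merged = W + scaling * (B @ A), where scaling is typically lora_alpha / r.
# The function name and argument names below are illustrative, not part of peft.
import torch

def fold_lora_into_weight(base_weight: torch.Tensor,
                          lora_A: torch.Tensor,
                          lora_B: torch.Tensor,
                          scaling: float) -> torch.Tensor:
    # base_weight: (out_features, in_features)
    # lora_A:      (r, in_features)
    # lora_B:      (out_features, r)
    return base_weight + scaling * (lora_B @ lora_A)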
# peft/tuners/lora.py - line 361
def _replace_module(self, parent_module, child_name, new_module, old_module):
    setattr(parent_module, child_name, new_module)
    # Reuse the (already merged) weight and bias tensors from the old module.
    new_module.weight = old_module.weight
    if hasattr(old_module, "bias"):
        if old_module.bias is not None:
            new_module.bias = old_module.bias

    if getattr(old_module, "state", None) is not None:
        new_module.state = old_module.state
        new_module.to(old_module.weight.device)

    # dispatch to correct device
    for name, module in new_module.named_modules():
        if "lora_" in name:
            module.to(old_module.weight.device)
        if "ranknum" in name:
            module.to(old_module.weight.device)
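# For reference, a simplified sketch of the _get_submodules helper called in
# _unload_and_optionally_merge above. The real helper lives in peft.utils; this
# reconstruction only assumes torch's nn.Module.get_submodule API. It splits a
# dotted module key into (parent module, child module, child attribute name).
def _get_submodules_sketch(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name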