New LoRA Initialization Method: Explained Variance Adaptation #2142
Conversation
Thanks a lot for this PR. I only skimmed it for now and have yet to read the details from the paper, but this looks quite promising! When looking at the implementation, I saw that the way the Eva weights are initialized sits a bit "outside" of how we normally do this in PEFT. The normal way is to have this on the layer level, check for instance `peft/src/peft/tuners/lora/layer.py`, lines 128 to 139 (at commit 5758a7e).
My guess why this is implemented differently here is that there is currently one pass through the modules that determines the individual ranks and creates the state dict. If this was refactored such that the weights were initialized on the layer level as usual, we would need two passes: one to adjust the ranks in the config, and then the actual weight initialization on the layer level. Would it be possible to refactor the code to enable that? My main concern with the current approach is that weight initialization currently happens in one place of the code but now there would be another place, making the code more disorganized. Second, IIUC, we would initialize the weights once and then override them with Eva parameters, which is wasteful. Finally, if we do this on the layer level, that could solve the multi-GPU issue, right?

Apart from this, what I noticed while skimming is that there is a big class (the copied `IncrementalPCA`), and I also had difficulties understanding some parts of the new code.

Going forward, we'll also have to add a few unit tests, but I'm fine with iterating on the implementation first and then adding the tests.
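To illustrate the two-pass idea, here is a rough sketch; the helper names (`compute_eva_rank_pattern`, `init_eva_weights_on_layer`) are made up for illustration and are not existing PEFT functions:

```python
from peft import get_peft_model

def get_peft_model_with_eva(base_model, lora_config, dataloader):
    # Pass 1 (hypothetical helper): forward some downstream data through the
    # base model to estimate per-layer ranks and store them in the config.
    lora_config.rank_pattern = compute_eva_rank_pattern(base_model, dataloader)

    # Pass 2: build the PEFT model as usual, then let each LoRA layer
    # initialize its own A matrix from the precomputed SVD components.
    model = get_peft_model(base_model, lora_config)
    for name, module in model.named_modules():
        if hasattr(module, "lora_A"):
            init_eva_weights_on_layer(module, name)  # hypothetical helper
    return model
```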
Thanks a lot for the quick feedback @BenjaminBossan! I added docstrings to the new classes. Regarding your other concerns about the implementation:
Thanks for the further explanations. Let me discuss the points separately.

Initialization

I see the issue with cycling through the data. This is indeed tricky, but I think it's worth putting some time into considering alternative solutions. Right now, the inputs are essentially obtained like this:

```python
inputs = {k: v.to(device) for k, v in next(iter(dataloader)).items() if k != "labels"}
model(**inputs)
```

Let me throw out an alternative: let's say that we had a function called `initialize_lora_eva_weights`. Usage would look like this:

```python
# normal PEFT code
base_model = ...
lora_config = LoraConfig(...)
model = get_peft_model(base_model, lora_config)
# new eva code
eva_config = EvaConfig(...)
sample_data = ... # user needs to define this instead of relying on data loader
def forward_fn(model):
    return model(**sample_data)
model = initialize_lora_eva_weights(model, config=eva_config, forward_fn=forward_fn, adapter_name="default")
```

This would make it so that Eva is a function that is applied to the LoRA model after it has been initialized. This avoids the issues I mentioned above, although it also has some disadvantages.
Do you think this would work with the Eva algorithm, or is there something I'm missing? LMK what you think.

IncrementalPCA

I wonder if we could use some of the sklearn unit tests and add them here to ensure that the copied class behaves as expected. Of course, this is all a bit outside of the scope of PEFT, but if there is no external package that provides the functionality, it would be okay to add it to PEFT.
Implementation

I was wondering already if having the dataloader in the config is a good idea, since you probably only want primitive types there. I think your suggested solution is a sensible one. I gave it some thought and don't really have a better suggestion given the restrictions we discussed. A few questions for clarification. One of them: currently, the inputs used for the SVD are determined like this:
```python
try:
    states = input.detach()
except AttributeError:
    states = input[0].detach()
```

So for cases where the input is a tuple, we just take the first value. Not super important at this stage, I guess, but I wanted to ask whether you think we should add an argument so the user can define which inputs should be used.

IncrementalPCA

It's a good idea to take some tests from sklearn. All tests for this class are in this file. Right now there is nothing like this available anywhere. I built my own package torch-incremental-pca, but I assumed you don't want any dependencies on small packages like this, so I just copied the class over here. I did, however, create a pull request to have this feature implemented in PyTorch directly, so in the future we might be able to remove this from PEFT again.
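As a starting point, one sklearn-style check could look roughly like this (a sketch only, assuming the vendored class follows sklearn's `partial_fit`/`components_` API on torch tensors):

```python
import torch

def test_incremental_pca_approximates_full_svd():
    torch.manual_seed(0)
    # Data with a decaying spectrum so the principal components are well defined.
    X = torch.randn(1000, 32) * torch.logspace(0, -2, 32)

    ipca = IncrementalPCA(n_components=8)  # the class vendored in this PR
    for batch in X.split(100):
        ipca.partial_fit(batch)

    # Reference components from a full SVD on the centered data.
    Xc = X - X.mean(0)
    _, _, Vh = torch.linalg.svd(Xc, full_matrices=False)

    # The incremental components should capture nearly as much variance
    # as the exact top-8 components.
    var_full = (Xc @ Vh[:8].T).var(dim=0).sum()
    var_ipca = (Xc @ ipca.components_.T).var(dim=0).sum()
    assert var_ipca > 0.95 * var_full
```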
Okay, then let's go with this general approach. Maybe we'll have some ideas for improvements along the way; then we can still pivot.
Yes, let's keep it as an option. By itself, it wouldn't do anything, but it helps with discoverability. We could also consider setting some kind of flag: when users want to use the model with Eva init but haven't called the function we discussed yet, the flag could be checked and an error/warning raised. Moreover, with this init option being set, we can signal to other packages that they need to call the function (e.g. trl could automatically call this in their trainer).
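For illustration, such a guard could be as simple as the following (the flag name `_eva_initialized` and the helper itself are made up):

```python
import warnings

def check_eva_initialized(model, adapter_name="default"):
    # Hypothetical helper: warn if Eva init was requested but the
    # initialization function has not been called yet.
    config = model.peft_config[adapter_name]
    if config.init_lora_weights == "eva" and not getattr(model, "_eva_initialized", False):
        warnings.warn(
            "init_lora_weights='eva' is set, but the Eva initialization function "
            "has not been called; the adapter still has its default weights."
        )
```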
Exactly, I don't think we absolutely need a dataloader, do we? AFAICT, we only need it to get a batch of data. If it needs to be a different batch each time, we could do:

```python
iter_dl = iter(cycle(dataloader))  # optionally wrap with tqdm

def forward_fn(model):
    sample_data = next(iter_dl)
    return model(**sample_data)
```
My reasoning for this is that there are too many possibilities for how the model could be called; we can't cover all of them. The assumption that we can just filter out labels and then call `model(**inputs)` won't hold in general. But thinking more about it, I'd say what we can do is keep your existing API of passing the dataloader and give the option for a custom `forward_fn` for users whose models need to be called differently, roughly as sketched below.
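Something like this combined API (the names and the default handling here are just an assumption):

```python
def default_forward_fn(model, batch):
    # Assumed default: drop labels and call the model directly.
    inputs = {k: v for k, v in batch.items() if k != "labels"}
    return model(**inputs)

def compute_eva_state_dict(model, dataloader, forward_fn=default_forward_fn):
    # Hypothetical entry point: iterate over batches and let forward_fn
    # decide how the model is actually called; the SVD updates themselves
    # would happen in the registered hooks.
    for batch in dataloader:
        forward_fn(model, batch)
```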
I'd definitely rather have an …
That's correct. We could make it an optional dependency, as it is only needed for Eva. But we could also copy the code over to PEFT ("vendoring"), which is similar to what we did with …
Nice. Of course, even if it lands, we can't use it immediately, as we want to support older PyTorch versions. But definitely add a comment with a link to the PR. I would imagine, however, that PyTorch won't be too keen on the sklearn API, but let's see.
@BenjaminBossan before going ahead with this implementation of calling a function/method after `get_peft_model`, I wanted to propose one more alternative. Instead of using a separate initialization function, we could have something like:

```python
class EvaModel(LoraModel):
    def __init__(self, model, config, adapter_name):
        self.model = model
        for layer in config.target_modules:
            module = get_module(self.model, layer)
            module.register_forward_hook(SVDHook(config))
```

During the first few forward passes of finetuning we would not actually update any weights but use the activations to compute the eva state dict. Once SVD has converged for all layers, we would switch the internal state of the model to finetuning (with some boolean variable) and only at that point add the adapters with the appropriate eva values for the lora A matrices. Hope this makes sense. That way we could leverage the finetuning loop directly. It seems a bit more seamless and works without needing any additional function calls after `get_peft_model`.
I like the proposal to seamlessly integrate this into the training loop. I'm not sure if we need a separate class for that, but you could come up with a prototype and I can check if there is a good way to integrate it without a separate class. With this, we would need to make sure that it works with vanilla training loops but also with frameworks like …
@BenjaminBossan I started implementing this approach of using the first few batches for the initialization, and unfortunately I realised that there are at least two major issues with it.
These issues make me wonder if some variant of my initially proposed solution would be feasible (computing the eva state dict before initializing the adapters). One idea I would have, which avoids passing a dataloader or something similar through the config, would be to allow additional keyword arguments in `get_peft_model`, roughly:

```python
def get_peft_model(
    model: PreTrainedModel,
    peft_config: PeftConfig,
    adapter_name: str = "default",
    mixed: bool = False,
    autocast_adapter_dtype: bool = True,
    revision: Optional[str] = None,
    **model_init_kwargs,
):
    ...
    PeftModel(model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype, **model_init_kwargs)
    ...
```
The call would then look like:

```python
peft_model = get_peft_model(model, peft_config, eva_dataloader=dataloader, eva_forward_fn=eva_forward_fn)
```

We also don't directly have to modify the … Let me know what you think about the issues I mentioned with the previously proposed approaches. If my suggested approach is not feasible, I am fine with going ahead with the implementation you suggested, with a …
Good points. There would probably be ways to work around this or to accept these shortfalls, but I agree that this is not ideal.
I'm a bit reluctant to add more options to `get_peft_model`. To avoid initializing the LoRA weights through …
Description
In our work "One Initialization to Rule them All: Fine-tuning via Explained Variance Adaptation" (paper), we introduce a new data-driven initialization method for LoRA adapters. We will present this work at the ENLSP (oral) and AFM workshops at NeurIPS'24.
Eva initializes LoRA in a data-driven manner, using SVD on information from the downstream data. It also adaptively allocates ranks throughout the model.
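As a rough sketch of the per-layer computation (simplified; scaling details and the exact rank-allocation rule from the paper are omitted here):

```python
import torch

def eva_init_for_layer(activations, r):
    # activations: (num_tokens, in_features) inputs to the layer, collected
    # from forward passes on downstream data; r: rank assigned to this layer.
    Xc = activations - activations.mean(0)
    # The top right-singular vectors are the directions that explain the most
    # variance in the layer's inputs.
    _, S, Vh = torch.linalg.svd(Xc, full_matrices=False)
    lora_A_init = Vh[:r]                           # (r, in_features)
    explained = S[:r].square() / S.square().sum()  # per-component explained variance
    return lora_A_init, explained

# Ranks are then redistributed across layers (the rank_pattern) so that the
# total rank budget goes to the components with the highest explained variance.
```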
Implementation Details
- We add a new `EvaConfig`, which at minimum needs to receive a value for the `dataloader` argument. Other parameters have default values which worked well in our experiments. The instantiated `EvaConfig` is passed as a parameter to `LoraConfig` (we took inspiration from the `LoftQConfig` implementation).
- In `LoraConfig`, `init_lora_weights` needs to be set to `"eva"`.
- We modify the `__init__` method of `LoraModel` to create a state dict for eva initialization before initializing the parent class, so that we can populate the `rank_pattern` argument in `LoraConfig`. This way we directly initialize with a modified rank distribution, allowing us to directly load the eva_state_dict.

Open Questions
- We are not sure if `self.model.load_state_dict(eva_state_dict, strict=False)` is always a valid way to initialize, e.g. when the model is wrapped with torch.compile.

Example Usage
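The example usage appears to have been truncated here; based on the description above, it would presumably look roughly like this (the exact import path and the name of the `LoraConfig` parameter holding the `EvaConfig` are assumptions):

```python
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import EvaConfig  # import path is an assumption

# dataloader yields batches of the downstream data used for the SVD passes
eva_config = EvaConfig(dataloader=dataloader)

lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],
    init_lora_weights="eva",
    eva_config=eva_config,  # assumed parameter name, analogous to loftq_config
)

# The eva state dict is computed in LoraModel.__init__, so initialization
# happens as part of get_peft_model.
model = get_peft_model(base_model, lora_config)
```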