New LoRA Initialization Method: Explained Variance Adaptation #2142

Open · wants to merge 15 commits into base: main
Conversation

@sirluk commented Oct 10, 2024

Description

In our work "One Initialization to Rule them All: Fine-tuning via Explained Variance Adaptation" (paper) we introduce a new data-driven initialization method for LoRA adapters. We will present this work at the ENLSP (oral) and AFM workshops at NeurIPS'24.

Eva initializes LoRA in a data-driven manner: it performs SVD on information from the downstream data and adaptively allocates ranks throughout the model.

Implementation Details

  1. As the initialization is data-driven, we created a new config class EvaConfig, which at minimum needs to receive a value for the dataloader argument. The other parameters have default values which worked well in our experiments. The instantiated EvaConfig is passed as a parameter to LoraConfig (we took inspiration from the LoftQConfig implementation).
    In LoraConfig, "init_lora_weights" needs to be set to "eva".
  2. We modify the __init__ method of LoraModel to create a state dict for eva initialization before initializing the parent class, so that we can populate the rank_pattern argument in LoraConfig. This way we directly initialize with a modified rank distribution, which allows us to directly load the eva_state_dict (a sketch of this flow is shown after the list).
  3. All other code necessary for the initialization is self-contained in peft/tuners/lora/eva.py.
  4. Under examples/eva_finetuning/eva_finetuning.py we add an example script for eva initialization.
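
As a rough illustration of the flow in point 2, here is a minimal sketch. It is written as a subclass purely for brevity (the PR modifies LoraModel.__init__ directly), and get_eva_state_dict is a hypothetical helper standing in for the computation in peft/tuners/lora/eva.py.

from peft.tuners.lora import LoraModel

class EvaInitializedLoraModel(LoraModel):
    def __init__(self, model, config, adapter_name):
        # run forward passes over the downstream data to get the eva state dict and
        # the adaptively allocated per-layer ranks (hypothetical helper)
        eva_state_dict, rank_pattern = get_eva_state_dict(model, config)
        # write the adapted rank distribution into the config before layer injection
        config.rank_pattern = rank_pattern
        # the parent class then creates the adapter layers with those ranks
        super().__init__(model, config, adapter_name)
        # load the SVD-based LoRA A matrices into the freshly created adapters
        self.model.load_state_dict(eva_state_dict, strict=False)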

Open Questions

  • Not sure if self.model.load_state_dict(eva_state_dict, strict=False) is always a valid way to initialize, e.g. when the model is wrapped with torch.compile.
  • We have currently only tested a single-GPU setting. We are not sure what the best way is to add multi-GPU support (only run the init on rank 0 and copy the result to the other ranks when done? A possible sketch is shown below).
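
One way the rank-0-plus-broadcast idea could look, as a minimal sketch: it assumes torch.distributed is already initialized, that the eva state dict is a flat dict of float32 tensors, and uses compute_eva_state_dict as a hypothetical stand-in for the actual computation.

import torch
import torch.distributed as dist

def build_eva_state_dict_distributed(model, dataloader, device):
    """Compute the eva state dict on rank 0 only and broadcast it to all other ranks."""
    rank = dist.get_rank()
    eva_state_dict = compute_eva_state_dict(model, dataloader) if rank == 0 else None  # hypothetical helper

    # share keys and shapes so non-zero ranks can allocate receive buffers
    meta = [{k: v.shape for k, v in eva_state_dict.items()}] if rank == 0 else [None]
    dist.broadcast_object_list(meta, src=0)
    if rank != 0:
        # assumes float32; dtypes could be broadcast alongside the shapes if needed
        eva_state_dict = {k: torch.empty(shape, device=device) for k, shape in meta[0].items()}

    # broadcast the actual tensors from rank 0 to every other rank
    for key in sorted(eva_state_dict):
        eva_state_dict[key] = eva_state_dict[key].to(device)
        dist.broadcast(eva_state_dict[key], src=0)
    return eva_state_dict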

Example Usage

# EvaConfig is introduced by this PR; LoraConfig and get_peft_model come from peft
from peft import LoraConfig, get_peft_model

eva_config = EvaConfig(dataloader=dataloader)
peft_config = LoraConfig(
    r=16,
    lora_alpha=1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    init_lora_weights="eva",
    eva_config=eva_config,
)
peft_model = get_peft_model(model, peft_config)

@BenjaminBossan (Member)

Thanks a lot for this PR. I only skimmed it for now and have yet to read the details from the paper but this looks quite promising!

When looking at the implementation, I saw that the way the Eva weights are initialized sits a bit "outside" of how we normally do this in PEFT. The normal way is to have this on the layer level, check for instance here:

# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
    with gather_params_ctx(self.get_base_layer().weight):
        self.pissa_init(adapter_name, init_lora_weights)
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
    with gather_params_ctx(self.get_base_layer().weight):
        self.olora_init(adapter_name)
elif init_lora_weights == "loftq":
    with gather_params_ctx(self.get_base_layer().weight):
        self.loftq_init(adapter_name)
elif init_lora_weights:
    self.reset_lora_parameters(adapter_name, init_lora_weights)

My guess why this is implemented differently here is that there is currently one pass through the modules that determines the individual ranks and creates the state dict. If this was refactored such that the weights were initialized on the layer level as usual, we would need two passes, one to adjust the ranks in the config and then the actual weight initialization on the layer level. Would it be possible to refactor the code to enable that?

My main concern with the current approach is that the weight initialization currently is happening in one place of the code but now there would be another place, making the code more disorganized. Second, IIUC, we would initialize the weights once and then override them with Eva parameters, which is wasteful. Finally, if we do this on the layer level, that could solve the multi-GPU issue, right?

Apart from this, what I noticed while skimming is that there is a big class IncrementalPCA which looks like a rewrite of the equivalent sklearn class using torch. My question is: Did you implement this from scratch or is it taken from somewhere? This has implications on how well this is tested etc.

I also had difficulties understanding the SVDHook and HashHook, could you please add a small docstring? Are these hooks permanently added?

Regarding torch.compile, we could add a test for that and see if it works. But this is not a necessity for this PR to land in PEFT, just something that's nice to have.

Going forward, we'll also have to add a few unit tests, but I'm fine with iterating on the implementation first and then adding the tests.

@sirluk (Author) commented Oct 10, 2024

Thanks a lot for the quick feedback @BenjaminBossan !

I added docstrings to SVDHook and HashHook, I hope this makes their purpose a bit clearer.

Regarding your other concerns about the implementation:

  • I agree that our proposed init method does not live on the layer level like the other inits. I thought a lot about where in the code to put the init and unfortunately I think it would be quite difficult to put it directly in the layer (or at least I haven't found an elegant way to do it). The way eva is initialized is that we do multiple forward passes through the model and incrementally perform SVD on the input activations of the target layers (usually more than 10 forward passes). The resulting singular vectors are used to initialize the LoRA A matrices, so unfortunately 2 passes would also not be enough. We would also need to somehow pass an instance of EvaConfig to the LoraLayer. (A rough sketch of the hook-based computation is shown after this list.)
    What we could maybe do is create the eva_state_dict as it is created now, then somehow pass the state dict to LoraModel._create_and_replace and from there the relevant tensors to the update_layer methods of the individual layers.
    I also added a statement in the update_layer method to initialize the LoRA B matrices with zeros, as needed for eva. For this reason we do not initialize weights twice, as the A matrices are not initialized in the update_layer method.

  • Regarding the new eva code: IncrementalPCA is indeed largely taken from the sklearn class, so a lot of the code is directly copied from there. What we added is seamless GPU support by setting the device for all internal variables / buffers to the device of the first input tensor. We also integrated support for torch.svd_lowrank to significantly speed up computations. We ran several internal tests to check that the results are equivalent to running the sklearn implementation on CPU.
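
To make the hook-based computation a bit more concrete, here is a simplified sketch (illustration only, not the actual SVDHook from peft/tuners/lora/eva.py; a real implementation would feed each batch into an incremental PCA and check for convergence rather than overwrite the estimate):

import torch

class SimplifiedSVDHook:
    """Forward hook that tracks the top-k right singular vectors of a layer's input activations."""

    def __init__(self, rank: int):
        self.rank = rank
        self.singular_vectors = None

    @torch.no_grad()
    def __call__(self, module, inputs, output):
        # forward hooks receive the positional inputs as a tuple
        x = inputs[0] if isinstance(inputs, tuple) else inputs
        x = x.detach().flatten(0, -2).float()  # (tokens, hidden_dim)
        # approximate truncated SVD of the current activation batch
        _, _, v = torch.svd_lowrank(x, q=self.rank)
        # keep only the latest estimate to keep the sketch short
        self.singular_vectors = v.T  # (rank, hidden_dim), later used to init lora_A

# usage sketch: register on each target module, run a few batches, read .singular_vectors
# handle = target_module.register_forward_hook(SimplifiedSVDHook(rank=16))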

@BenjaminBossan (Member)

Thanks for the further explanations. Let me discuss the points separately.

Initialization

I see the issue with the cycling through the data. This is indeed tricky but I think it's worth putting some time into considering alternative solutions.

Right now, the dataloader is added to the config. We should really try to avoid this. The config should be a small, json-serializable object. This means it can't really contain any reference to the data. Moreover, there are so many unknowns about the data and how the model is called that I don't think we should have code like:

    inputs = {k: v.to(device) for k, v in next(iter(dataloader)).items() if k != "labels"}
    model(**inputs)

Let me throw out an alternative: Let's say that we had a function called initialize_lora_eva_weights or something like that. The API would be:

# normal PEFT code
base_model = ...
lora_config = LoraConfig(...)
model = get_peft_model(base_model, lora_config)

# new eva code
eva_config = EvaConfig(...)
sample_data = ...  # user needs to define this instead of relying on data loader

def forward_fn(model):
    return model(**sample_data)

model = initialize_lora_eva_weights(model, config=eva_config, forward_fn=forward_fn, adapter_name="default")

This would make it so that Eva is a function that is applied to the LoRA model after it was initialized. This avoids the issues I mentioned above.

The disadvantages I see with this approach:

  1. Eva initialization is a separate step. Therefore, packages that want to support it have to explicitly account for it.
  2. We would initialize the weights twice (but it's not really worse than right now).

Do you think this would work with the Eva algorithm or is there something I'm missing? LMK what you think.

IncrementalPCA

I wonder if we could use some of the sklearn unit tests and add them here to ensure that the IncrementalPCA produces the correct results. That way, we have the safety of knowing it works as expected without too much effort of adding our own unit tests.

Of course, this is all a bit outside of scope for PEFT, but if there is no external package that provides the functionality, it would be okay to add it to PEFT.
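
For illustration, such a test could look roughly like this (a sketch; the import path and class interface of the torch reimplementation are assumptions):

import numpy as np
import torch
from sklearn.decomposition import IncrementalPCA as SklearnIncrementalPCA

# assumed import path for the torch-based class added in this PR
from peft.tuners.lora.eva import IncrementalPCA as TorchIncrementalPCA

def test_matches_sklearn():
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 32)).astype(np.float32)

    ref = SklearnIncrementalPCA(n_components=8, batch_size=50).fit(X)
    ipca = TorchIncrementalPCA(n_components=8, batch_size=50)
    ipca.fit(torch.from_numpy(X))

    # singular vectors are only defined up to sign, so compare absolute values
    np.testing.assert_allclose(
        np.abs(ipca.components_.cpu().numpy()), np.abs(ref.components_), atol=1e-4
    )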

@sirluk (Author) commented Oct 11, 2024

Implementation

I was already wondering if having the dataloader in the config is a good idea, since you probably only want primitive types there. I think your suggested solution is a sensible one. I gave it some thought and don't really have a better suggestion given the restrictions we discussed.

A few questions for clarification:

  • Would eva still be a valid value for init_lora_weights in LoraConfig? I am a bit worried about the discoverability of eva if LoraConfig only contained normal LoRA settings and someone would need to know about the function initialize_lora_eva_weights to use eva.
  • In your suggested solution, would sample data not be a dataloader?
  • Is it necessary to define a custom forward function for the eva init? Usually these would be the same inputs that the model receives during normal finetuning anyway. I did adapt the line you mentioned so that the label key is filtered out of the inputs dict.
  • One question about how we get the layer activations for SVD. In the forward hook SVDHook we currently have it implemented like this:
try:
    states = input.detach()
except AttributeError:
    states = input[0].detach()

So for cases where the input is a tuple, we just take the first value. Not super important at this stage I guess, but I just wanted to ask if you think we should add an argument so the user can define which inputs should be used.

IncrementalPCA

It's a good idea to take some tests from sklearn. All tests for this class are in this file. Right now there is nothing like this available anywhere. I built my own package torch-incremental-pca, but I just assumed you don't want any dependencies on small packages like this, so I copied the class over here. I did however create a pull request to have this feature implemented in pytorch directly, so in the future we might be able to remove this from peft again.

@BenjaminBossan (Member)

I was already wondering if having the dataloader in the config is a good idea, since you probably only want primitive types there. I think your suggested solution is a sensible one. I gave it some thought and don't really have a better suggestion given the restrictions we discussed.

Okay, then let's go with this general approach. Maybe we have some ideas for improvements along the way, then we can still pivot.

  • Would eva still be a valid value for init_lora_weights in LoraConfig? I am a bit worried about the discoverability of eva if LoraConfig only contained normal LoRA settings and someone would need to know about the function initialize_lora_eva_weights to use eva.

Yes, let's keep it as an option. By itself it wouldn't do anything, but it helps with discoverability. We could also consider setting some kind of flag: if users try to use the model with Eva init before calling the function we discussed, the flag could be checked and an error/warning raised (a small sketch is shown below). Moreover, with this init option being set, we can signal to other packages that they need to call the function (e.g. trl could automatically call this in their SFTTrainer).
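
A minimal sketch of such a check, where _eva_initialized is a hypothetical flag that the (also hypothetical at this point) initialize_lora_eva_weights function would set:

import warnings

def check_eva_initialized(model, adapter_name="default"):
    """Warn if the config requests eva init but the eva weights were never computed."""
    config = model.peft_config[adapter_name]
    if config.init_lora_weights == "eva" and not getattr(model, "_eva_initialized", False):
        warnings.warn(
            "init_lora_weights='eva' is set but initialize_lora_eva_weights() has not been "
            "called; the LoRA A matrices still use the default initialization."
        )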

  • In your suggested solution, would sample data not be a dataloader?

Exactly, I don't think we absolutely need a dataloader, do we? AFAICT, we only need it to get a batch of data. If it needs to be a different batch each time, we could do:

from itertools import cycle

iter_dl = iter(cycle(dataloader))  # optionally wrap with tqdm

def forward_fn(model):
    sample_data = next(iter_dl)
    return model(**sample_data)

Is it necessary to define a custom forward function for the eva init? Usually these would be the same inputs that the model receives during normal finetuning anyway. I did adapt the line you mentioned so that the label key is filtered out of the inputs dict.

My reasoning for this is that there are too many possibilities for how the model could be called; we can't cover all of them. The assumption that we can just filter out labels and then call model(**inputs) might mostly work for standard HF LLM architectures, but this method should also work with other model types, right? Let's keep the possibilities open.

But thinking more about it, I'd say what we can do is to keep your existing API of passing the dataloader and give the option for a custom forward_fn on top. That way, there is a more convenient way for the standard models and a more flexible and advanced way for all the others.

  • One question about how we get the layer activations for SVD.

I'd definitely rather have an instance check than a try ... except for this. But I'm not sure if that's what you were asking.
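
For example, the snippet above could be written with an instance check along these lines (a sketch, not the final implementation):

# take the first element if the hook receives a tuple of inputs, otherwise use the tensor directly
if isinstance(input, tuple):
    states = input[0].detach()
else:
    states = input.detach()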

I just assumed you don't want any dependencies on small packages like this so I just copied the class over here.

That's correct. We could make it an optional dependency, as it is only needed for Eva. But we could also copy the code over to PEFT ("vendoring"), which is similar to what we did with BufferDict. What I would like to see: If you copy some of the sklearn tests to your package and ensure that they pass, we wouldn't need them in PEFT too, we could just say: "check this URL for the original code". This way, there is a bit less clutter.

I did however create a pull request to have this feature implemented in pytorch directly so in the future we might be able to remove this from peft again.

Nice. Of course, even if it lands, we can't use it immediately as we want to support older PyTorch versions. But definitely add a comment with a link to the PR. I would imagine, however, that PyTorch won't be too keen on the sklearn API, but let's see.

@sirluk (Author) commented Oct 11, 2024

@BenjaminBossan before going ahead with the implementation of calling a function/method after get_peft_model, I just wanted to propose one alternative idea to get your thoughts on it.

Instead of using LoraModel, we create a subclass EvaModel. Initially, this model class would not have any adapter layers, but we would add the forward hooks needed for the SVD computations. In the __init__ we would have something like:

class EvaModel(LoraModel):

    def __init__(self, model, config, adapter_name):
        self.model = model
        for layer_name in config.target_modules:
            module = get_module(self.model, layer_name)  # get_module: placeholder for fetching a submodule by name
            module.register_forward_hook(SVDHook(config))

During the first few forward passes of finetuning we would not actually update any weights but use the activations to compute the eva state dict. Once SVD has converged for all layers, we would switch the internal state of the model to finetuning (via some boolean flag) and only at that point add the adapters with the appropriate eva values for the LoRA A matrices.

Hope this makes sense. That way we could leverage the finetuning loop directly. It seems a bit more seamless and works without needing any additional function calls after get_peft_model. But of course there might be some hidden issues I am not considering at the moment.

@BenjaminBossan (Member)

I like the proposal to seamlessly integrate this into the training loop. I'm not sure if we need a separate class for that, but you could come up with a prototype and I can check if there is a good way to integrate it without a separate class.

With this, we would need to make sure that it works with vanilla training loops but also frameworks like Trainer from transformers. If it does, I think this could be the optimal solution.

@sirluk (Author) commented Oct 13, 2024

@BenjaminBossan I started implementing this approach of using the first few batches for the initialization and unfortunately I realised that there are at least 2 major issues with it:

  • The LoRA layers must be initialized at the beginning of training so that they can be passed to the optimizer. This is especially important if it needs to work out of the box with the HuggingFace Trainer. However, this means that the LoRA weights would take up unnecessary GPU memory that would be needed for the SVD computation. This is btw also an issue with the solution you proposed of doing a normal LoRA init first and then calling a function to initialize eva.
  • Further, if you use a learning rate schedule, it would already adjust the learning rate while you are still executing the SVD steps. This means that when you start fine-tuning, the learning rate is different from the intended initial learning rate. This would be especially problematic if there is a warmup, I would assume.

These issues make me wonder if some variant of my initially proposed solution would be feasible (computing the eva state dict before initializing the adapters).

One idea I would have which avoids passing a dataloader or something similar through the config would be to enable the get_peft_model function to accept some optional keyword arguments through **kwargs and pass the dataloader as well as a model forward_fn this way:

def get_peft_model(
    model: PreTrainedModel,
    peft_config: PeftConfig,
    adapter_name: str = "default",
    mixed: bool = False,
    autocast_adapter_dtype: bool = True,
    revision: Optional[str] = None,
    **model_init_kwargs,
):
    ...
    PeftModel(model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype, **model_init_kwargs)
    ...



peft_model = get_peft_model(model, peft_config, eva_dataloader=dataloader, eva_forward_fn=eva_forward_fn)

We also wouldn't have to modify the __init__ method of LoraModel directly; defining a _pre_injection_hook for LoraModel should also do the trick.

Let me know what you think about the issues I mentioned with the previously proposed approaches. If my suggested approach is not feasible, I am fine with going ahead with the implementation you suggested with an initialize_lora_eva_weights function (this will unfortunately also suffer from the downside mentioned above that the already initialized LoRA weights take up GPU memory that could be useful for the SVD computation, which we might want to run with a larger batch size than finetuning).

@BenjaminBossan (Member)

I started implementing this approach of using the first few batches for the initialization and unfortunately I realised that there are at least 2 major issues with it:

Good points. There would probably be ways to work around this or accept these shortcomings, but I agree that this is not ideal.

One idea I would have which avoids passing a dataloader or something similar through the config would be to enable the get_peft_model function to accept some optional keyword arguments through **kwargs and pass the dataloader as well as a model forward_fn this way.

I'm a bit reluctant to add more options to get_peft_model, especially when they're only used in a very narrow use case (e.g. this option only works for LoRA but get_peft_model is the entry point for all PEFT methods). Therefore, I tend to prefer the option of having a dedicated function.

To avoid initializing the LoRA weights through get_peft_model, I have a proposal. We recently added an option low_cpu_mem_usage=True to PEFT, which results in the PEFT weights being initialized on the meta device. This option has not been propagated to get_peft_model because until now, there was no reason to have it there. However, with Eva, it would be useful. So we could add the option there and then just pass it along to PeftModel in this line, which already accepts that argument. That way, we avoid initializing the LoRA weights (a small sketch follows). WDYT about that?
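
A minimal sketch of the proposed change, assuming get_peft_model simply forwards the new argument (the exact signature shown here is an assumption):

# sketch: forward low_cpu_mem_usage from get_peft_model to PeftModel so the LoRA
# weights are created on the meta device and never materialized before the eva init
def get_peft_model(model, peft_config, adapter_name="default", low_cpu_mem_usage=False, **kwargs):
    ...
    return PeftModel(
        model,
        peft_config,
        adapter_name=adapter_name,
        low_cpu_mem_usage=low_cpu_mem_usage,  # PeftModel already accepts this argument
        **kwargs,
    )

# usage: adapter weights start on the meta device; the eva init then materializes them
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)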
