Add model folder automatically, and update dependencies
SeanScripts committed Sep 26, 2024
1 parent 5714cee commit ed67d9e
Showing 4 changed files with 34 additions and 47 deletions.
19 changes: 11 additions & 8 deletions README.md
@@ -16,25 +16,28 @@ Along with some utility nodes for working with text:
- Join String
- Select Index
- Slice List

Install the latest version of transformers, which has support for Pixtral/Llama Vision models:
`python_embeded\python.exe -m pip install git+https://github.com/huggingface/transformers`

Requires transformers 4.45.0 for Pixtral and 4.46.0 for Llama Vision.
## Installation

Available in [ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) as ComfyUI-PixtralLlamaVision. When installed from ComfyUI-Manager, the required packages will be installed automatically.

If you install by cloning this repo into your custom nodes folder, you'll need to install `transformers >= 4.45.0` to load Pixtral and Llama Vision models:
`python_embeded\python.exe -m pip install transformers --upgrade`
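If you're not sure which version you already have, `pip show` will report it (assuming the same portable `python_embeded` layout as the command above):
`python_embeded\python.exe -m pip show transformers`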

Also install bitsandbytes if you don't have it already:
`python_embeded\python.exe -m pip install bitsandbytes`

Models should be placed in the `ComfyUI/models/pixtral` and `ComfyUI/models/llama-vision` folders, with each model in its own subfolder containing the `model.safetensors` file along with the tokenizer and any config files it needs.

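The loader nodes simply list the subfolders of those two directories, so a layout along these lines is what they expect (the model folder names below are only illustrative):

```
ComfyUI/models/pixtral/
    pixtral-12b-nf4/
        model.safetensors
        config.json
        (tokenizer and processor config files from the model repo)
ComfyUI/models/llama-vision/
    Llama-3.2-11B-Vision-Instruct-nf4/
        model.safetensors
        config.json
        (tokenizer and processor config files from the model repo)
```
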
You can get a 4-bit quantized version of Pixtral-12B which is compatible with these custom nodes here: [https://huggingface.co/SeanScripts/pixtral-12b-nf4](https://huggingface.co/SeanScripts/pixtral-12b-nf4)
You can get 4-bit quantized versions of Pixtral-12B and Llama-3.2-11B-Vision-Instruct that are compatible with these custom nodes here:

[https://huggingface.co/SeanScripts/pixtral-12b-nf4](https://huggingface.co/SeanScripts/pixtral-12b-nf4)

You can get a 4-bit quantized version of Llama-3.2-11B-Vision-Instruct which is compatible with these custom nodes here:
[https://huggingface.co/SeanScripts/Llama-3.2-11B-Vision-Instruct-nf4](https://huggingface.co/SeanScripts/Llama-3.2-11B-Vision-Instruct-nf4)

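Outside of ComfyUI, these nf4 checkpoints load like any other local transformers model. A minimal sketch (the folder path and `device_map` below are illustrative assumptions, not values taken from the nodes):

```python
from transformers import LlavaForConditionalGeneration, AutoProcessor

# Illustrative local path: any model folder under ComfyUI/models/pixtral works the same way.
model_path = "ComfyUI/models/pixtral/pixtral-12b-nf4"

# Pre-quantized nf4 checkpoints typically carry their bitsandbytes settings in config.json,
# so from_pretrained applies them automatically as long as bitsandbytes is installed.
model = LlavaForConditionalGeneration.from_pretrained(model_path, device_map="cuda")
processor = AutoProcessor.from_pretrained(model_path)
```

For the Llama Vision checkpoint, the same pattern applies with `MllamaForConditionalGeneration` and a folder under `ComfyUI/models/llama-vision`.
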
Unfortunately, the Pixtral nf4 model has considerably degraded performance on some tasks, like OCR. The Llama Vision model seems to be better for this task.

I also tested Pixtral's object detection with bounding box generation and it seems to sort of work with approximate results, though it still fails quite often. The full model might be better, or maybe it could be improved with finetuning.
## Examples

Example Pixtral image captioning (not saving the output to a text file in this example):
![Example Pixtral image captioning workflow](pixtral_caption_example.jpg)
@@ -51,4 +54,4 @@ Since Pixtral directly tokenizes the input images, it's able to handle them inli
Example Llama Vision object detection with bounding box:
![Example Llama Vision object detection with bounding box workflow](llama_vision_bounding_box_example.jpg)

Both models kind of work for this, but not that well. They definitely have some understanding of the positions of objects in the image, though. Maybe it needs a better prompt. Or a non-quantized model. Or a finetune. But it does sometimes work.
59 changes: 21 additions & 38 deletions nodes.py
@@ -1,39 +1,30 @@
import comfy.utils
import comfy.model_management as mm
import folder_paths

from transformers import AutoProcessor, BitsAndBytesConfig, set_seed

pixtral = True
llama_vision = True
# transformers 4.45.0
try:
    from transformers import LlavaForConditionalGeneration
except ImportError:
    print("[ComfyUI-PixtralLlamaVision] Can't load Pixtral, need to update transformers")
    pixtral = False

# transformers 4.46.0
try:
    from transformers import MllamaForConditionalGeneration
except ImportError:
    print("[ComfyUI-PixtralLlamaVision] Can't load Llama Vision, need to update transformers")
    llama_vision = False

# Requires transformers >= 4.45.0
from transformers import LlavaForConditionalGeneration, MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, set_seed
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import time
import os
from pathlib import Path
import re

pixtral_model_dir = os.path.join(folder_paths.models_dir, "pixtral")
llama_vision_model_dir = os.path.join(folder_paths.models_dir, "llama-vision")
# Add pixtral and llama-vision folders if not present
if not os.path.exists(pixtral_model_dir):
    os.makedirs(pixtral_model_dir)
if not os.path.exists(llama_vision_model_dir):
    os.makedirs(llama_vision_model_dir)

class PixtralModelLoader:
"""Loads a Pixtral model. Add models as folders inside the `ComfyUI/models/pixtral` folder. Each model folder should contain a standard transformers loadable safetensors model along with a tokenizer and any config files needed."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": ([item.name for item in Path(folder_paths.models_dir, "pixtral").iterdir() if item.is_dir()],),
"model_name": ([item.name for item in Path(pixtral_model_dir).iterdir() if item.is_dir()],),
}
}

@@ -43,7 +34,7 @@ def INPUT_TYPES(s):
TITLE = "Load Pixtral Model"

def load_model(self, model_name):
model_path = os.path.join(folder_paths.models_dir, "pixtral", model_name)
model_path = os.path.join(pixtral_model_dir, model_name)
device = mm.get_torch_device()
model = LlavaForConditionalGeneration.from_pretrained(
model_path,
@@ -115,7 +106,7 @@ class LlamaVisionModelLoader:
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": ([item.name for item in Path(folder_paths.models_dir, "llama-vision").iterdir() if item.is_dir()],),
                "model_name": ([item.name for item in Path(llama_vision_model_dir).iterdir() if item.is_dir()],),
            }
        }

@@ -125,7 +116,7 @@ def INPUT_TYPES(s):
TITLE = "Load Llama Vision Model"

def load_model(self, model_name):
model_path = os.path.join(folder_paths.models_dir, "llama-vision", model_name)
model_path = os.path.join(llama_vision_model_dir, model_name)
device = mm.get_torch_device()
model = MllamaForConditionalGeneration.from_pretrained(
model_path,
@@ -433,6 +424,13 @@ def select_index(self, list, start_index, end_index):
# Batch Count works for getting list length

NODE_CLASS_MAPPINGS = {
"PixtralModelLoader": PixtralModelLoader,
"PixtralGenerateText": PixtralGenerateText,
# Not really much need to work with the image tokenization directly for something like image captioning, but might be interesting later...
#"PixtralImageEncode": PixtralImageEncode,
#"PixtralTextEncode": PixtralTextEncode,
"LlamaVisionModelLoader": LlamaVisionModelLoader,
"LlamaVisionGenerateText": LlamaVisionGenerateText,
"RegexSplitString": RegexSplitString,
"RegexSearch": RegexSearch,
"RegexFindAll": RegexFindAll,
@@ -443,19 +441,4 @@ def select_index(self, list, start_index, end_index):
"SliceList": SliceList,
}

if pixtral:
    NODE_CLASS_MAPPINGS |= {
        "PixtralModelLoader": PixtralModelLoader,
        "PixtralGenerateText": PixtralGenerateText,
        # Not really much need to work with the image tokenization directly for something like image captioning, but might be interesting later...
        #"PixtralImageEncode": PixtralImageEncode,
        #"PixtralTextEncode": PixtralTextEncode,
    }

if llama_vision:
    NODE_CLASS_MAPPINGS |= {
        "LlamaVisionModelLoader": LlamaVisionModelLoader,
        "LlamaVisionGenerateText": LlamaVisionGenerateText,
    }

NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui-pixtralllamavision"
description = "For loading and running Pixtral and Llama 3.2 Vision models"
version = "2.1.0"
version = "2.1.1"
license = {file = "LICENSE"}

[project.urls]
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
transformers >= 4.45.0
accelerate
bitsandbytes
