diff --git a/.gitignore b/.gitignore index a360c0f8..0d3bac98 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ debug .vscode/ logs .DS_Store +imgui.ini +playpen/__repo_level_awareness # --- Python --- diff --git a/kai/models/report_types.py b/kai/models/report_types.py index 9385872a..80cc0015 100644 --- a/kai/models/report_types.py +++ b/kai/models/report_types.py @@ -58,8 +58,11 @@ class ExtendedIncident(Incident): ruleset_name: str ruleset_description: Optional[str] = None + violation_name: str violation_description: Optional[str] = None + violation_category: Category = Category.POTENTIAL + violation_labels: list[str] = [] class Link(BaseModel): diff --git a/playpen/client/__init__.py b/playpen/middleman/__init__.py similarity index 100% rename from playpen/client/__init__.py rename to playpen/middleman/__init__.py diff --git a/playpen/middleman/fake_gui_sdl2.py b/playpen/middleman/fake_gui_sdl2.py new file mode 100644 index 00000000..f8d976e8 --- /dev/null +++ b/playpen/middleman/fake_gui_sdl2.py @@ -0,0 +1,1080 @@ +""" +NOTE(JonahSussman): This is probably some of the ugliest code that I have ever +written in my entire life. I am so sorry. + +The main goal with this code is to create a fake GUI that can be used to +interact with the RPC server. This is useful for testing purposes, as it allows +us to rapidly iterate on the RPC server without needing to interface with the +IDE. 
+""" + +import ctypes +import json +import os +import subprocess # trunk-ignore(bandit/B404) +import sys +import threading +from abc import ABC, abstractmethod +from enum import Enum +from io import BufferedReader, BufferedWriter +from pathlib import Path +from types import NoneType +from typing import IO, Any, TypeVar, Union, cast, get_args, get_origin +from urllib.parse import urlparse + +import imgui # type: ignore[import-untyped] +import OpenGL.GL as gl # type: ignore[import-untyped] +import yaml +from imgui.integrations.sdl2 import SDL2Renderer # type: ignore[import-untyped] +from pydantic import BaseModel, ValidationError +from sdl2 import ( # type: ignore[import-untyped] + SDL_GL_ACCELERATED_VISUAL, + SDL_GL_CONTEXT_FLAGS, + SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG, + SDL_GL_CONTEXT_MAJOR_VERSION, + SDL_GL_CONTEXT_MINOR_VERSION, + SDL_GL_CONTEXT_PROFILE_CORE, + SDL_GL_CONTEXT_PROFILE_MASK, + SDL_GL_DEPTH_SIZE, + SDL_GL_DOUBLEBUFFER, + SDL_GL_MULTISAMPLEBUFFERS, + SDL_GL_MULTISAMPLESAMPLES, + SDL_GL_STENCIL_SIZE, + SDL_HINT_MAC_CTRL_CLICK_EMULATE_RIGHT_CLICK, + SDL_HINT_VIDEO_HIGHDPI_DISABLED, + SDL_INIT_EVERYTHING, + SDL_QUIT, + SDL_WINDOW_OPENGL, + SDL_WINDOW_RESIZABLE, + SDL_WINDOWPOS_CENTERED, + SDL_CreateWindow, + SDL_DestroyWindow, + SDL_Event, + SDL_GetError, + SDL_GL_CreateContext, + SDL_GL_DeleteContext, + SDL_GL_MakeCurrent, + SDL_GL_SetAttribute, + SDL_GL_SetSwapInterval, + SDL_GL_SwapWindow, + SDL_Init, + SDL_PollEvent, + SDL_Quit, + SDL_SetHint, +) + +from kai.models.kai_config import KaiConfigModels +from kai.models.report import Report +from kai.models.report_types import Category, ExtendedIncident +from kai.models.util import remove_known_prefixes +from playpen.middleman.server import ( + GetCodeplanAgentSolutionParams, + GitVFSUpdateParams, + KaiRpcApplicationConfig, + TestRCMParams, +) +from playpen.rpc.core import JsonRpcApplication, JsonRpcServer +from playpen.rpc.models import JsonRpcId +from playpen.rpc.streams import BareJsonStream + 
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)


def try_construct_base_model(cls: type[BaseModelT]) -> BaseModelT:
    """Build an instance of ``cls`` by repeatedly patching validation errors.

    Starts from an empty input dict and calls ``cls.model_validate``. Each
    reported error is patched in the input and validation is retried:

    * ``missing``      -> the field is set to ``None``
    * ``string_type``  -> the field is set to ``""``

    Any other error type is printed and the process exits with status 1.

    Args:
        cls: The pydantic model class to instantiate.

    Returns:
        A validated instance of ``cls``.
    """
    inp: dict[Any, Any] = {}

    def set_field(d: dict[Any, Any], loc: tuple[int | str, ...], value: Any) -> None:
        """Assign ``value`` at the nested location ``loc`` inside ``d``."""
        key, rest = loc[0], loc[1:]
        if not rest:
            d[key] = value
            return

        # Reuse an existing intermediate dict. Replacing it with a fresh `{}`
        # would discard sibling fields that were patched earlier.
        child = d.get(key)
        if not isinstance(child, dict):
            child = {}
            d[key] = child
        set_field(child, rest, value)

    while True:
        try:
            return cls.model_validate(inp)
        except ValidationError as validation_error:
            for err in validation_error.errors():
                if err["type"] == "missing":
                    set_field(inp, err["loc"], None)
                elif err["type"] == "string_type":
                    set_field(inp, err["loc"], "")
                else:
                    print(err)
                    # `sys.exit` rather than the `site`-provided `exit`, which
                    # is only meant for interactive sessions.
                    sys.exit(1)


THIS_FILE_PATH = Path(os.path.abspath(__file__)).resolve()
THIS_DIR_PATH = THIS_FILE_PATH.parent
KAI_DIR = THIS_DIR_PATH.parent.parent


class Drawable(ABC):
    """Base class for a window that can be toggled on and off.

    Subclasses implement ``_draw``; callers invoke ``draw`` once per frame.
    """

    def __init__(self, *, show: bool = True) -> None:
        self.show = show

    def draw(self) -> None:
        """Call ``_draw`` only while ``self.show`` is true."""
        if self.show:
            self._draw()

    @abstractmethod
    def _draw(self) -> None: ...
# Default `initialize` parameters shown in the Configuration Editor.
# NOTE(review): several paths below are machine-specific and must be edited in
# the GUI (or here) before use on another machine.
CONFIG = KaiRpcApplicationConfig(
    process_id=os.getpid(),
    root_path=KAI_DIR / "example/coolstore",
    analyzer_lsp_lsp_path=Path(
        "/home/jonah/.vscode/extensions/redhat.java-1.35.1-linux-x64/server/bin/jdtls"
    ),
    analyzer_lsp_rpc_path=KAI_DIR / "analyzer-lsp",
    analyzer_lsp_rules_path=Path(
        "/home/jonah/Projects/github.com/konveyor/rulesets/default/generated"
    ),
    analyzer_lsp_java_bundle_path=Path(
        "/home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/notebooks/kai-analyzer-code-plan/java-bundle/java-analyzer-bundle.core-1.0.0-SNAPSHOT.jar"
    ),
    model_provider=KaiConfigModels(
        provider="ChatIBMGenAI",
        args={
            "model_id": "meta-llama/llama-3-70b-instruct",
        },
    ),
    kai_backend_url="http://localhost:8080",
    log_level="TRACE",
)


class ConfigurationEditor(Drawable):
    """Tabbed editor with one tab per RPC method and widgets for its params.

    ``model_and_extra`` maps a method name to ``(params_model, extra)``, where
    ``extra`` is a scratch dict holding per-widget UI state (combo indices,
    list lengths, raw JSON text) keyed by the dotted field path.
    """

    def __init__(self, method_and_params: list[tuple[str, BaseModel]]) -> None:
        super().__init__()

        self.model_and_extra: dict[str, tuple[BaseModel, dict[str, Any]]] = {}
        for method, params in method_and_params:
            self.model_and_extra[method] = (params, {})

        print(self.model_and_extra)

    @staticmethod
    def _is_plain_subclass(cls: Any, parent: type) -> bool:
        """Return True if ``cls`` is a real, non-generic-alias subclass of ``parent``.

        ``issubclass`` raises ``TypeError`` for annotations that are not
        classes (e.g. ``tuple[int, str]``), so check that first.
        """
        return (
            get_origin(cls) is None
            and isinstance(cls, type)
            and issubclass(cls, parent)
        )

    def set_params(
        self, method: str, params: BaseModel, extra: dict[str, Any] | None = None
    ) -> None:
        """Replace the params model (and UI state) stored for ``method``."""
        if extra is None:
            extra = {}

        self.model_and_extra[method] = (params, extra)

    def _draw(self) -> None:
        _, self.show = imgui.begin("Configuration Editor", closable=True)

        if imgui.begin_tab_bar("ConfigurationEditorTabBar"):
            for method in self.model_and_extra:
                if imgui.begin_tab_item(method).selected:
                    self.draw_tab(method)
                    imgui.end_tab_item()

            imgui.end_tab_bar()

        imgui.end()

    def draw_tab(self, method: str) -> None:
        """Draw the "Populate" button and the editable params for ``method``."""
        model, extra = self.model_and_extra[method]

        if imgui.button(f"Populate `{method}` request"):
            JSON_RPC_REQUEST_WINDOW.rpc_kind_n = 0
            JSON_RPC_REQUEST_WINDOW.rpc_method = method

            try:
                # Round-trip through model_validate so the text reflects a
                # validated copy of the current widget values.
                JSON_RPC_REQUEST_WINDOW.rpc_params = model.__class__.model_validate(
                    model.model_dump()
                ).model_dump_json(
                    indent=2,
                )
            except Exception as e:
                JSON_RPC_REQUEST_WINDOW.rpc_params = f"Error parsing JSON: {e}"

            imgui.set_window_focus_labeled("JSON RPC Request Window")

        imgui.separator()

        model, extra = self.draw_obj(
            model,
            extra,
            model.__class__.__name__,
            model.__class__,
        )
        # Store the result so an object rebuilt by draw_obj is not lost.
        self.model_and_extra[method] = (model, extra)

    def draw_obj(
        self, obj: Any, extra: dict[str, Any], full_name: str, cls: type
    ) -> tuple[Any, dict[str, Any]]:
        """
        obj is what it is, cls is what it should be

        Dispatches on ``cls`` to the matching widget and returns the
        (possibly updated) value together with the UI state dict.
        """
        origin, _args = get_origin(cls), get_args(cls)
        name = full_name.split(".")[-1]

        if origin is Union:
            return self.draw_union(obj, extra, full_name, cls)

        # Built-in types
        elif cls is None or cls is NoneType:
            return None, extra
        elif cls is str or cls is Path:
            return imgui.input_text(f"{name}##{full_name}", str(obj), 400)[1], extra
        elif cls is int:
            return imgui.input_int(f"{name}##{full_name}", obj)[1], extra
        elif cls is float:
            return imgui.input_float(f"{name}##{full_name}", obj)[1], extra
        elif cls is bool:
            return imgui.checkbox(f"{name}##{full_name}", obj)[1], extra

        elif origin is dict or cls is dict:
            return self.draw_dict(obj, extra, full_name, cls)
        elif origin is list or cls is list:
            return self.draw_list(obj, extra, full_name, cls)
        elif self._is_plain_subclass(cls, BaseModel):
            return self.draw_base_model(obj, extra, full_name, cls)
        elif self._is_plain_subclass(cls, Enum):
            return self.draw_enum(obj, extra, full_name, cls)
        else:
            imgui.text(f"{full_name}: {cls} (unknown)")
            return obj, extra

    def draw_enum(
        self, obj: Enum | Any, extra: dict[str, Any], full_name: str, cls: type[Enum]
    ) -> tuple[Enum, dict[str, Any]]:
        """Draw a combo box over the member names of ``cls``."""
        name = full_name.split(".")[-1]

        ENUM_ITEMS_KEY = f"{full_name}.enum_items"
        ENUM_CURRENT_KEY = f"{full_name}.enum_current"

        if ENUM_ITEMS_KEY not in extra:
            extra[ENUM_ITEMS_KEY] = [item.name for item in cls]
        if ENUM_CURRENT_KEY not in extra:
            # Start on the member equal to obj, or on the first member.
            extra[ENUM_CURRENT_KEY] = next(
                (i for i, item in enumerate(cls) if item == obj),
                0,
            )

        _, extra[ENUM_CURRENT_KEY] = imgui.combo(
            label=f"{name}##{full_name}",
            current=extra[ENUM_CURRENT_KEY],
            items=extra[ENUM_ITEMS_KEY],
        )

        return cls[extra[ENUM_ITEMS_KEY][extra[ENUM_CURRENT_KEY]]], extra

    def draw_union(
        self, obj: Any, extra: dict[str, Any], full_name: str, cls: type
    ) -> tuple[Any, dict[str, Any]]:
        """Draw a type selector for a ``Union`` and then the selected member."""
        args = get_args(cls)
        name = full_name.split(".")[-1]

        UNION_ITEMS_KEY = f"{full_name}.union_items"
        UNION_CURRENT_KEY = f"{full_name}.union_current"

        if UNION_ITEMS_KEY not in extra:
            extra[UNION_ITEMS_KEY] = [arg.__name__ for arg in args]
        if UNION_CURRENT_KEY not in extra:
            # -1 when no member matches obj's class; `args[-1]` below then
            # selects the last member of the union.
            extra[UNION_CURRENT_KEY] = next(
                (i for i, arg in enumerate(args) if arg == obj.__class__),
                -1,
            )

        _, extra[UNION_CURRENT_KEY] = imgui.combo(
            label=f"{name}'s type##{full_name}",
            current=extra[UNION_CURRENT_KEY],
            items=extra[UNION_ITEMS_KEY],
        )

        return self.draw_obj(
            obj,
            extra,
            full_name,
            args[extra[UNION_CURRENT_KEY]],
        )

    def draw_base_model(
        self,
        obj: BaseModel | Any,
        extra: dict[str, Any],
        full_name: str,
        cls: type[BaseModel],
    ) -> tuple[BaseModel, dict[str, Any]]:
        """Draw every field of a pydantic model, writing edits back onto it."""
        if not isinstance(obj, cls):
            obj = try_construct_base_model(cls)

        imgui.text(full_name)

        imgui.indent()

        # Why is this necessary?
        if not isinstance(obj.__pydantic_fields_set__, set):
            obj.__pydantic_fields_set__ = set()

        for field_name, field in cls.model_fields.items():
            field_full_name = f"{full_name}.{field_name}"
            if field_full_name not in extra:
                extra[field_full_name] = {}

            result, extra[field_full_name] = self.draw_obj(
                getattr(obj, field_name),
                extra[field_full_name],
                field_full_name,
                field.annotation if field.annotation is not None else NoneType,
            )

            setattr(obj, field_name, result)

        imgui.unindent()

        return obj, extra

    def draw_dict(
        self,
        obj: dict[Any, Any] | Any,
        extra: dict[str, Any],
        full_name: str,
        cls: type[dict[Any, Any]],
    ) -> tuple[dict[Any, Any] | Exception, dict[str, Any]]:
        """Edit a dict as raw JSON text.

        Returns the parsed dict, or the raised exception object itself when
        the current text is not valid JSON.
        """
        if not isinstance(obj, dict):
            obj = {}

        name = full_name.split(".")[-1]

        DICT_KEY = f"{full_name}.dict"
        if DICT_KEY not in extra:
            try:
                extra[DICT_KEY] = json.dumps(obj, indent=2)
            except Exception:
                extra[DICT_KEY] = "{}"

        _, extra[DICT_KEY] = imgui.input_text_multiline(
            f"{name}##{full_name}",
            extra[DICT_KEY],
            4096 * 16,
        )

        try:
            return json.loads(extra[DICT_KEY]), extra
        except Exception as e:
            return e, extra

    def draw_list(
        self,
        obj: list[Any] | Any,
        extra: dict[str, Any],
        full_name: str,
        cls: type[list[Any]],
    ) -> tuple[list[Any], dict[str, Any]]:
        """Draw a length input followed by one widget per list element."""
        args = get_args(cls)
        # An unparameterised `list` is edited as a list of strings.
        list_cls = str if len(args) == 0 else args[0]
        items_are_models = self._is_plain_subclass(list_cls, BaseModel)

        if not isinstance(obj, list):
            obj = []

        name = full_name.split(".")[-1]

        LEN_KEY = f"{full_name}.len"
        if LEN_KEY not in extra:
            extra[LEN_KEY] = len(obj)

        _, extra[LEN_KEY] = imgui.input_int(
            f"{name} length##{full_name}", extra[LEN_KEY]
        )

        if extra[LEN_KEY] < 0:
            extra[LEN_KEY] = 0

        # Grow with None placeholders or truncate to the requested length.
        if extra[LEN_KEY] > len(obj):
            obj.extend([None for _ in range(extra[LEN_KEY] - len(obj))])
        elif extra[LEN_KEY] < len(obj):
            obj = obj[: extra[LEN_KEY]]

        for i in range(len(obj)):
            ITEM_FULL_NAME_KEY = f"{full_name}.{i}"
            if ITEM_FULL_NAME_KEY not in extra:
                extra[ITEM_FULL_NAME_KEY] = {}

            # Models print their own header, everything else gets one here.
            if not items_are_models:
                imgui.text(f"{name}.{i}")
                imgui.indent()

            obj[i], extra[ITEM_FULL_NAME_KEY] = self.draw_obj(
                obj[i],
                extra[ITEM_FULL_NAME_KEY],
                ITEM_FULL_NAME_KEY,
                list_cls,
            )

            if not items_are_models:
                imgui.unindent()

        return obj, extra


class SourceEditor(Drawable):
    """Read-only view of a source file with incident lines shown in red."""

    def __init__(self) -> None:
        super().__init__()

        self.application_path = "/home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/example/coolstore"
        self.relative_filename = (
            "/src/main/java/com/redhat/coolstore/service/ShippingService.java"
        )
        self.report_path = "/home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/example/analysis/coolstore/output.yaml"

        self.relative_filename_path: Path | None = None
        self.file_path: Path | None = None

        self.editor_content = ""
        self.report: Report | None = None
        # Keyed by 1-based line number.
        self.incidents: dict[int, ExtendedIncident] = {}

    def _draw(self) -> None:

        window_name = "Source Code Editor"
        if self.file_path:
            window_name += f" - {self.relative_filename}"

        _, self.show = imgui.begin(
            window_name, flags=imgui.WINDOW_MENU_BAR, closable=True
        )

        if imgui.begin_menu_bar():
            if imgui.begin_menu("File"):
                _clicked, FILE_LOADER.show = imgui.menu_item(
                    "Load File/Analysis Report", selected=FILE_LOADER.show
                )
                imgui.end_menu()
            imgui.end_menu_bar()

        for line_number, line in enumerate(self.editor_content.split("\n"), 1):
            if line_number in self.incidents:
                incident = self.incidents[line_number]

                # imgui.same_line()
                imgui.text_colored(f"{line_number:4d}: {line}", 1.0, 0.0, 0.0)
                # imgui.text_colored(f" [Issue: {highlighted_lines[idx]['message']}] ", 1.0, 0.0, 0.0)

                # Hover message
                if imgui.is_item_hovered():
                    imgui.begin_tooltip()

                    imgui.text_colored(incident.ruleset_name, 1.0, 0.0, 0.0)
                    imgui.text_colored(incident.violation_name, 1.0, 0.0, 0.0)
                    imgui.separator()
                    imgui.text(incident.message)
                    imgui.end_tooltip()

                # Context menu
                if imgui.is_item_clicked(imgui.MOUSE_BUTTON_RIGHT):
                    imgui.open_popup(f"context_menu_{line_number}")

                if imgui.begin_popup(f"context_menu_{line_number}"):
                    if imgui.begin_menu("Populate..."):
                        # if imgui.selectable("getRAGSolution")[0]:
                        #     JSON_RPC_REQUEST_WINDOW.rpc_kind_n = 0
                        #     JSON_RPC_REQUEST_WINDOW.rpc_method = "getRAGSolution"
                        #     JSON_RPC_REQUEST_WINDOW.rpc_params = json.dumps(
                        #         PostGetIncidentSolutionsForFileParams(
                        #             file_name=str(self.file_path),
                        #             file_contents=self.editor_content,
                        #             application_name="coolstore",
                        #             incidents=[incident],
                        #         ).model_dump(),
                        #         indent=2,
                        #     )

                        #     imgui.set_window_focus_labeled("JSON RPC Request Window")

                        if imgui.selectable("getCodeplanAgentSolution")[0]:
                            send_incident = incident.model_copy(deep=True)
                            # Map the report's container path onto the
                            # application path set in the File Loader. With
                            # the default application_path this equals the
                            # previously hard-coded directory.
                            app_root = self.application_path.rstrip("/") + "/"
                            send_incident.uri = send_incident.uri.replace(
                                "/opt/input/source/",
                                app_root,
                            )

                            CONFIGURATION_EDITOR.set_params(
                                "getCodeplanAgentSolution",
                                GetCodeplanAgentSolutionParams(
                                    file_path=self.file_path or Path("/"),
                                    incidents=[send_incident],
                                ),
                            )

                            imgui.set_window_focus_labeled("Configuration Editor")

                        imgui.end_menu()
                    imgui.end_popup()
            else:
                imgui.text(f"{line_number:4d}: {line}")

        imgui.end()


class FileLoader(Drawable):
    """Dialog that loads a source file and its report into ``SOURCE_EDITOR``."""

    def __init__(self) -> None:
        super().__init__()
        self.show = False

    def _draw(self) -> None:
        if not self.show:
            return

        _, self.show = imgui.begin("File Loader", closable=True)

        imgui.text("Load File/Analysis Report")
        _, SOURCE_EDITOR.application_path = imgui.input_text(
            "Application Path", SOURCE_EDITOR.application_path, 400
        )
        _, SOURCE_EDITOR.relative_filename = imgui.input_text(
            "Filename", SOURCE_EDITOR.relative_filename, 400
        )
        _, SOURCE_EDITOR.report_path = imgui.input_text(
            "Report", SOURCE_EDITOR.report_path, 400
        )

        if imgui.button("Load"):
            application_path = Path(SOURCE_EDITOR.application_path).resolve()
            parts = Path(SOURCE_EDITOR.relative_filename).resolve().parts
            # Drop the root component so the filename joins under the
            # application path instead of replacing it.
            while len(parts) > 0 and parts[0] == "/":
                parts = parts[1:]

            SOURCE_EDITOR.relative_filename_path = Path(*parts)
            SOURCE_EDITOR.file_path = application_path.joinpath(*parts)
            SOURCE_EDITOR.editor_content = self.load_file()
            SOURCE_EDITOR.incidents = self.load_analysis_report()

            self.show = False

        imgui.end()

    def load_file(self) -> str:
        """Return the text of ``SOURCE_EDITOR.file_path``, or "" if unset."""
        if SOURCE_EDITOR.file_path is None:
            return ""

        with open(SOURCE_EDITOR.file_path, "r") as file:
            return file.read()

    def load_analysis_report(self) -> dict[Any, Any]:
        """Return the report's incidents for the loaded file, keyed by line number.

        If several incidents share a line number, the last one wins.
        """
        report = Report.load_report_from_file(SOURCE_EDITOR.report_path)
        result = {}

        for ruleset_name, ruleset in report.rulesets.items():
            for violation_name, violation in ruleset.violations.items():
                for incident in violation.incidents:
                    if report.should_we_skip_incident(incident):
                        continue

                    file_path = Path(remove_known_prefixes(urlparse(incident.uri).path))
                    if file_path != SOURCE_EDITOR.relative_filename_path:
                        # print(f"{SOURCE_EDITOR.relative_filename} != {file_path}")
                        continue

                    result[incident.line_number] = ExtendedIncident(
                        ruleset_name=ruleset_name,
                        violation_name=violation_name,
                        ruleset_description=ruleset.description,
                        violation_description=violation.description,
                        **incident.model_dump(),
                    )

        return result

        # NOTE: Break glass in case of emergency

        # return {
        #     1: ExtendedIncident(
        #         uri=f"file://{SOURCE_EDITOR.file_path}",
        #         message="This is a test incident",
        #         code_snip="print('Hello, world!')",
        #         line_number=1,
        #         variables={},
        #         ruleset_name="test ruleset",
        #         violation_name="test violation",
        #     ),
        # }


class JsonRpcRequestWindow(Drawable):
    """Form for composing a request or notification and submitting it."""

    def __init__(self) -> None:
        super().__init__()

        self.rpc_kind_items = ["Request", "Notification"]
        self.rpc_kind_n = 0
        self.rpc_method = ""
        self.rpc_params = ""

    @property
    def rpc_kind(self) -> str:
        """The currently selected entry of ``rpc_kind_items``."""
        return self.rpc_kind_items[self.rpc_kind_n]

    def _draw(self) -> None:
        _, self.show = imgui.begin("JSON RPC Request Window", closable=True)

        _, self.rpc_kind_n = imgui.combo(
            label="Kind", current=self.rpc_kind_n, items=self.rpc_kind_items
        )
        _, self.rpc_method = imgui.input_text("Method", self.rpc_method, 512)
        _, self.rpc_params = imgui.input_text_multiline(
            "Params", self.rpc_params, 4096 * 16
        )

        if imgui.button("Submit"):
            submit_json_rpc_request(self.rpc_kind, self.rpc_method, self.rpc_params)

        imgui.end()


class RequestResponseInspector(Drawable):
    """Table of submitted requests, with a detail tab per selected entry."""

    def __init__(self) -> None:
        super().__init__()

        # Maps a request id (or a synthetic key for notifications) to the
        # history entry shown in its own tab.
        self.selected_indices: dict[Any, Any] = {}

    def _draw(self) -> None:
        _, self.show = imgui.begin("Request/Response Inspector", closable=True)

        if imgui.begin_tab_bar("RequestsTabBar"):

            if imgui.begin_tab_item("Requests").selected:

                if imgui.begin_table(
                    "Requests", 3, imgui.TABLE_RESIZABLE | imgui.TABLE_SCROLL_Y
                ):
                    imgui.table_setup_column("ID")
                    imgui.table_setup_column("Request")
                    imgui.table_setup_column("Response")
                    imgui.table_headers_row()

                    for row, entry in enumerate(json_rpc_responses):
                        request = entry["request"]
                        # Notifications are stored without an "id"; use a
                        # per-row key instead of raising KeyError.
                        request_id = request.get("id", f"notification-{row}")

                        imgui.table_next_column()
                        if imgui.selectable(str(request_id))[0]:
                            self.selected_indices[request_id] = entry

                        imgui.table_next_column()
                        if imgui.selectable(str(request))[0]:
                            self.selected_indices[request_id] = entry
                        if imgui.is_item_hovered():
                            imgui.begin_tooltip()
                            imgui.text(yaml.dump(request))
                            imgui.end_tooltip()

                        imgui.table_next_column()

                        if imgui.selectable(str(entry["response"]))[0]:
                            self.selected_indices[request_id] = entry
                        if imgui.is_item_hovered():
                            imgui.begin_tooltip()
                            imgui.text(yaml.dump(entry["response"]))
                            imgui.end_tooltip()

                    imgui.end_table()
                imgui.end_tab_item()

            for idx, entry in self.selected_indices.items():
                if imgui.begin_tab_item(f"Request {idx}").selected:
                    imgui.text("Request")

                    imgui.begin_child(
                        "scrollable request",
                        border=True,
                        flags=imgui.WINDOW_ALWAYS_VERTICAL_SCROLLBAR
                        | imgui.WINDOW_HORIZONTAL_SCROLLING_BAR,
                    )
                    lines = yaml.dump(entry).split("\n")
                    for line in lines:
                        imgui.text(line)
                    imgui.end_child()

                    # imgui.separator()

                    # imgui.text("Response")

                    # imgui.begin_child("scrollable response", border=True, flags=imgui.WINDOW_ALWAYS_VERTICAL_SCROLLBAR | imgui.WINDOW_HORIZONTAL_SCROLLING_BAR)
                    # imgui.text(yaml.dump(entry["response"]))
                    # imgui.end_child()

                    imgui.end_tab_item()

            imgui.end_tab_bar()

        imgui.end()


class SubprocessInspector(Drawable):
    """Start/stop control for the RPC subprocess plus its stderr log."""

    def __init__(self) -> None:
        super().__init__()

        self.scroll_to_bottom = False

    def _draw(self) -> None:
        global rpc_subprocess_stderr_log

        _, self.show = imgui.begin("Subprocess Manager", closable=True)

        if rpc_subprocess is None:
            if imgui.button("Start"):
                # Use the interpreter running this GUI; a bare "python" may
                # be missing from PATH or point at another environment.
                start_server([sys.executable, str(rpc_script_path)])
        else:
            if imgui.button("Stop"):
                stop_server()

        imgui.same_line()
        if imgui.button("Clear log"):
            rpc_subprocess_stderr_log.clear()

        imgui.text("Subprocess stderr log:")

        imgui.begin_child(
            "scrollable_region",
            border=True,
            flags=imgui.WINDOW_ALWAYS_VERTICAL_SCROLLBAR
            | imgui.WINDOW_HORIZONTAL_SCROLLING_BAR,
        )

        # NOTE(review): the flag is only reset here; no scrolling is done.
        if self.scroll_to_bottom:
            self.scroll_to_bottom = False

        for log_line in rpc_subprocess_stderr_log:
            imgui.text(log_line)
        imgui.end_child()

        imgui.end()


rpc_application: JsonRpcApplication = JsonRpcApplication()

# Most recent `gitVFSUpdate` payload received from the server, if any.
GIT_VFS_UPDATE_PARAMS: GitVFSUpdateParams | None = None


@rpc_application.add_notify(method="gitVFSUpdate")
def handle_git_vfs_update(
    app: JsonRpcApplication,
    server: JsonRpcServer,
    id: JsonRpcId,
    params: GitVFSUpdateParams,
) -> None:
    """Store the latest ``gitVFSUpdate`` params for the Git VFS Inspector."""
    global GIT_VFS_UPDATE_PARAMS

    GIT_VFS_UPDATE_PARAMS = params


class GitVFSInspector(Drawable):
    """Shows the most recent ``gitVFSUpdate`` notification as YAML."""

    def __init__(self) -> None:
        super().__init__()

    def _draw(self) -> None:
        _, self.show = imgui.begin("Git VFS Inspector", closable=True)

        if GIT_VFS_UPDATE_PARAMS is None:
            imgui.text("No gitVFSUpdate notification received.")
        else:
            imgui.text("gitVFSUpdate notification received.")
            imgui.text(yaml.dump(GIT_VFS_UPDATE_PARAMS.model_dump()))

        imgui.end()


# Global variables

SOURCE_EDITOR = SourceEditor()
FILE_LOADER = FileLoader()
JSON_RPC_REQUEST_WINDOW = JsonRpcRequestWindow()
REQUEST_RESPONSE_INSPECTOR = RequestResponseInspector()
SUBPROCESS_INSPECTOR = SubprocessInspector()
GIT_VFS_INSPECTOR = GitVFSInspector()

CONFIGURATION_EDITOR = ConfigurationEditor(
    method_and_params=[
        ("initialize", CONFIG),
        (
            "getCodeplanAgentSolution",
            GetCodeplanAgentSolutionParams(
                file_path=Path(
                    "/home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/example/coolstore/src/main/java/com/redhat/coolstore/service/ShippingService.java"
                ),
                incidents=[
                    ExtendedIncident(
                        uri="file:///home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/example/coolstore/src/main/java/com/redhat/coolstore/service/ShippingService.java",
                        message='Remote EJBs are not supported in Quarkus, and therefore its use must be removed and replaced with REST functionality. In order to do this:\n 1. Replace the `@Remote` annotation on the class with a `@jakarta.ws.rs.Path("")` annotation. An endpoint must be added to the annotation in place of `` to specify the actual path to the REST service.\n 2. Remove `@Stateless` annotations if present. Given that REST services are stateless by nature, it makes it unnecessary.\n 3. For every public method on the EJB being converted, do the following:\n - In case the method has no input parameters, annotate the method with `@jakarta.ws.rs.GET`; otherwise annotate it with `@jakarta.ws.rs.POST` instead.\n - Annotate the method with `@jakarta.ws.rs.Path("")` and give it a proper endpoint path. As a rule of thumb, the method name can be used as endpoint, for instance:\n ```\n @Path("/increment")\n public void increment() \n ```\n - Add `@jakarta.ws.rs.QueryParam("")` to any method parameters if needed, where `` is a name for the parameter.',
                        code_snip=" 2 \n 3 import java.math.BigDecimal;\n 4 import java.math.RoundingMode;\n 5 \n 6 import javax.ejb.Remote;\n 7 import javax.ejb.Stateless;\n 8 \n 9 import com.redhat.coolstore.model.ShoppingCart;\n10 \n11 @Stateless\n12 @Remote\n13 public class ShippingService implements ShippingServiceRemote {\n14 \n15 @Override\n16 public double calculateShipping(ShoppingCart sc) {\n17 \n18 if (sc != null) {\n19 \n20 if (sc.getCartItemTotal() >= 0 && sc.getCartItemTotal() < 25) {\n21 \n22 return 2.99;",
                        line_number=12,
                        variables={
                            "file": "file:///home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/example/coolstore/src/main/java/com/redhat/coolstore/service/ShippingService.java",
                            "kind": "Class",
                            "name": "Stateless",
                            "package": "com.redhat.coolstore.service",
                        },
                        ruleset_name="quarkus/springboot",
                        ruleset_description="This ruleset gives hints to migrate from Springboot devtools to Quarkus",
                        violation_name="remote-ejb-to-quarkus-00000",
                        violation_description="Remote EJBs are not supported in Quarkus",
                        violation_category=Category.MANDATORY,
                        violation_labels=[
                            "konveyor.io/source=java-ee",
                            "konveyor.io/source=jakarta-ee",
                            "konveyor.io/target=quarkus",
                        ],
                    )
                ],
            ),
        ),
        (
            "testRCM",
            TestRCMParams(
                rcm_root=Path(
                    "/home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/"
                ),
                file_path=Path(
                    "/home/jonah/Projects/github.com/konveyor-ecosystem/kai-jonah/test_file.py"
                ),
                new_content="print('Hello, world!')",
            ),
        ),
    ]
)

# CONFIGURATION_EDITOR = ConfigurationEditorOld()

# History of submitted requests/notifications; each entry has the keys
# "kind", "request" and "response".
json_rpc_responses: list[dict[str, Any]] = []
# Lines read from the subprocess' stderr, plus start/stop markers.
rpc_subprocess_stderr_log: list[str] = []
# The running server process, or None while stopped.
rpc_subprocess = None

rpc_script_path = Path(os.path.dirname(os.path.realpath(__file__))) / "main.py"
+rpc_server: JsonRpcServer | None = None + + +def start_server(command: list[str]) -> None: + global rpc_subprocess + global rpc_subprocess_stderr_log + global rpc_server + + # trunk-ignore-begin(bandit/B603) + rpc_subprocess = subprocess.Popen( + command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # trunk-ignore-end(bandit/B603) + + rpc_subprocess_stderr_log.append("Subprocess started.") + + def read_stderr() -> None: + global rpc_subprocess_stderr_log + + while True: + stderr_line = cast(IO[bytes], rpc_subprocess.stderr).readline() + if stderr_line: + rpc_subprocess_stderr_log.append(stderr_line.decode("utf-8").strip()) + SUBPROCESS_INSPECTOR.scroll_to_bottom = True + else: + break + + threading.Thread(target=read_stderr, daemon=True).start() + + rpc_server = JsonRpcServer( + json_rpc_stream=BareJsonStream( + cast(BufferedReader, rpc_subprocess.stdout), + cast(BufferedWriter, rpc_subprocess.stdin), + ), + request_timeout=None, + app=rpc_application, + ) + rpc_server.start() + + +def stop_server() -> None: + global rpc_subprocess + global rpc_server + global rpc_subprocess_stderr_log + + if rpc_subprocess is not None: + rpc_subprocess.terminate() + rpc_subprocess = None + rpc_subprocess_stderr_log.append("Subprocess terminated.") + + if rpc_server is not None: + rpc_server.stop() + rpc_server = None + + +def submit_json_rpc_request(kind: str, method: str, params: Any) -> None: + global rpc_server + + try: + params_dict = json.loads(params) + except json.JSONDecodeError as e: + print(f"Invalid JSON: {e}") + return + + if rpc_server is None: + print("RPC server is not running.") + return + + def asyncly_send_request() -> None: + table_request = { + "method": method, + "params": params_dict, + } + + if kind == "Request": + table_request["id"] = rpc_server.next_id + + idx = len(json_rpc_responses) + json_rpc_responses.append( + {"kind": kind, "request": table_request, "response": None} + ) + + if kind == "Request": + response = 
rpc_server.send_request(method, params_dict) + elif kind == "Notification": + json_rpc_responses.append( + {"kind": kind, "request": table_request, "response": None} + ) + + response = None + else: + raise ValueError(f"Invalid RPC kind: {kind}") + + if response is None: + json_rpc_responses[idx]["response"] = {"note": "Response is None"} + else: + json_rpc_responses[idx]["response"] = response.model_dump( + exclude={"jsonrpc"} + ) + + threading.Thread(target=asyncly_send_request).start() + + +# Main loop +def main() -> None: + window, gl_context = impl_pysdl2_init() + + imgui.create_context() + impl = SDL2Renderer(window) + imgui_io = imgui.get_io() + imgui_io.fonts.add_font_default() + imgui_io.font_global_scale = 1.3 + imgui.style_colors_dark() + + event = SDL_Event() + running = True + while running: + while SDL_PollEvent(ctypes.byref(event)) != 0: + if event.type == SDL_QUIT: + running = False + break + impl.process_event(event) + impl.process_inputs() + + imgui.new_frame() + + imgui.begin_main_menu_bar() + + if imgui.begin_menu("View"): + clicked, SOURCE_EDITOR.show = imgui.menu_item( + "Source Editor", selected=SOURCE_EDITOR.show + ) + clicked, JSON_RPC_REQUEST_WINDOW.show = imgui.menu_item( + "JSON RPC Request Window", selected=JSON_RPC_REQUEST_WINDOW.show + ) + clicked, REQUEST_RESPONSE_INSPECTOR.show = imgui.menu_item( + "Request/Response Inspector", selected=REQUEST_RESPONSE_INSPECTOR.show + ) + clicked, SUBPROCESS_INSPECTOR.show = imgui.menu_item( + "Subprocess Inspector", selected=SUBPROCESS_INSPECTOR.show + ) + clicked, CONFIGURATION_EDITOR.show = imgui.menu_item( + "Configuration Editor", selected=CONFIGURATION_EDITOR.show + ) + clicked, GIT_VFS_INSPECTOR.show = imgui.menu_item( + "Git VFS Inspector", selected=GIT_VFS_INSPECTOR.show + ) + imgui.end_menu() + + imgui.end_main_menu_bar() + + SOURCE_EDITOR.draw() + FILE_LOADER.draw() + JSON_RPC_REQUEST_WINDOW.draw() + REQUEST_RESPONSE_INSPECTOR.draw() + SUBPROCESS_INSPECTOR.draw() + 
CONFIGURATION_EDITOR.draw() + GIT_VFS_INSPECTOR.draw() + + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + + imgui.render() + impl.render(imgui.get_draw_data()) + SDL_GL_SwapWindow(window) + + impl.shutdown() + SDL_GL_DeleteContext(gl_context) + SDL_DestroyWindow(window) + SDL_Quit() + + +def impl_pysdl2_init() -> tuple[int, Any]: + width, height = 1280, 720 + window_name = "Fake GUI for RPC Server" + + if SDL_Init(SDL_INIT_EVERYTHING) < 0: + print( + "Error: SDL could not initialize! SDL Error: " + + SDL_GetError().decode("utf-8") + ) + sys.exit(1) + + SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1) + SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24) + SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8) + SDL_GL_SetAttribute(SDL_GL_ACCELERATED_VISUAL, 1) + SDL_GL_SetAttribute(SDL_GL_MULTISAMPLEBUFFERS, 1) + SDL_GL_SetAttribute(SDL_GL_MULTISAMPLESAMPLES, 8) + SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG) + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4) + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 1) + SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE) + + SDL_SetHint(SDL_HINT_MAC_CTRL_CLICK_EMULATE_RIGHT_CLICK, b"1") + SDL_SetHint(SDL_HINT_VIDEO_HIGHDPI_DISABLED, b"1") + + window = SDL_CreateWindow( + window_name.encode("utf-8"), + SDL_WINDOWPOS_CENTERED, + SDL_WINDOWPOS_CENTERED, + width, + height, + SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE, + ) + + if window is None: + print( + "Error: Window could not be created! SDL Error: " + + SDL_GetError().decode("utf-8") + ) + sys.exit(1) + + gl_context = SDL_GL_CreateContext(window) + if gl_context is None: + print( + "Error: Cannot create OpenGL Context! SDL Error: " + + SDL_GetError().decode("utf-8") + ) + sys.exit(1) + + SDL_GL_MakeCurrent(window, gl_context) + if SDL_GL_SetSwapInterval(1) < 0: + print( + "Warning: Unable to set VSync! 
SDL Error: " + SDL_GetError().decode("utf-8") + ) + sys.exit(1) + + return window, gl_context + + +if __name__ == "__main__": + main() diff --git a/playpen/rpc_server/fake_ide.py b/playpen/middleman/fake_ide.py similarity index 94% rename from playpen/rpc_server/fake_ide.py rename to playpen/middleman/fake_ide.py index 0d81c511..91cfa211 100644 --- a/playpen/rpc_server/fake_ide.py +++ b/playpen/middleman/fake_ide.py @@ -14,8 +14,10 @@ from typing import IO, cast from kai.models.kai_config import KaiConfigModels -from playpen.rpc_server.rpc import BareJsonStream, JsonRpcServer, get_logger -from playpen.rpc_server.server import KaiRpcApplication +from playpen.middleman.server import KaiRpcApplication +from playpen.rpc.core import JsonRpcServer +from playpen.rpc.streams import BareJsonStream +from playpen.rpc.util import get_logger log = get_logger("jsonrpc") diff --git a/playpen/rpc_server/main.py b/playpen/middleman/main.py similarity index 81% rename from playpen/rpc_server/main.py rename to playpen/middleman/main.py index 16c7805e..b271a536 100644 --- a/playpen/rpc_server/main.py +++ b/playpen/middleman/main.py @@ -3,8 +3,10 @@ from io import BufferedReader, BufferedWriter from typing import cast -from playpen.rpc_server.rpc import TRACE, BareJsonStream, JsonRpcServer, get_logger -from playpen.rpc_server.server import app +from playpen.middleman.server import app +from playpen.rpc.core import JsonRpcServer +from playpen.rpc.streams import BareJsonStream +from playpen.rpc.util import TRACE, get_logger log = get_logger("jsonrpc") @@ -18,6 +20,7 @@ def main() -> None: add_arguments(parser) _args = parser.parse_args() + log.setLevel(TRACE) log.info("Starting Kai RPC Server") log.log(TRACE, "Trace log level enabled") diff --git a/playpen/middleman/server.py b/playpen/middleman/server.py new file mode 100644 index 00000000..41c7e1fc --- /dev/null +++ b/playpen/middleman/server.py @@ -0,0 +1,446 @@ +import logging +import sys +import threading +import traceback +from 
pathlib import Path +from typing import Any, Optional, cast +from unittest.mock import MagicMock +from urllib.parse import urlparse + +from pydantic import BaseModel + +from kai.models.kai_config import KaiConfigModels +from kai.models.report_types import ExtendedIncident, Incident, RuleSet, Violation +from kai.service.llm_interfacing.model_provider import ModelProvider +from playpen.repo_level_awareness.agent.dependency_agent.dependency_agent import ( + MavenDependencyAgent, +) +from playpen.repo_level_awareness.agent.reflection_agent import ReflectionAgent +from playpen.repo_level_awareness.api import RpcClientConfig, Task, TaskResult +from playpen.repo_level_awareness.codeplan import TaskManager +from playpen.repo_level_awareness.task_runner.analyzer_lsp.api import ( + AnalyzerRuleViolation, +) +from playpen.repo_level_awareness.task_runner.analyzer_lsp.task_runner import ( + AnalyzerTaskRunner, +) +from playpen.repo_level_awareness.task_runner.analyzer_lsp.validator import ( + AnalyzerLSPStep, +) +from playpen.repo_level_awareness.task_runner.compiler.compiler_task_runner import ( + MavenCompilerTaskRunner, +) +from playpen.repo_level_awareness.task_runner.compiler.maven_validator import ( + MavenCompileStep, +) +from playpen.repo_level_awareness.task_runner.dependency.task_runner import ( + DependencyTaskRunner, +) +from playpen.repo_level_awareness.vfs.git_vfs import ( + RepoContextManager, + RepoContextSnapshot, +) +from playpen.rpc.core import JsonRpcApplication, JsonRpcServer +from playpen.rpc.logs import JsonRpcLoggingHandler +from playpen.rpc.models import JsonRpcError, JsonRpcErrorCode, JsonRpcId +from playpen.rpc.util import DEFAULT_FORMATTER, TRACE, CamelCaseBaseModel + + +class KaiRpcApplicationConfig(CamelCaseBaseModel): + process_id: Optional[int] + + root_path: Path + model_provider: KaiConfigModels + kai_backend_url: str + + log_level: str = "INFO" + stderr_log_level: str = "TRACE" + file_log_level: Optional[str] = None + log_dir_path: 
Optional[Path] = None + + analyzer_lsp_lsp_path: Path + analyzer_lsp_rpc_path: Path + analyzer_lsp_rules_path: Path + analyzer_lsp_java_bundle_path: Path + + +class KaiRpcApplication(JsonRpcApplication): + def __init__(self) -> None: + super().__init__() + + self.initialized = False + self.config: Optional[KaiRpcApplicationConfig] = None + self.log = logging.getLogger("kai_rpc_application") + + +app = KaiRpcApplication() + +ERROR_NOT_INITIALIZED = JsonRpcError( + code=JsonRpcErrorCode.ServerErrorStart, + message="Server not initialized", +) + + +@app.add_request(method="echo") +def echo( + app: KaiRpcApplication, server: JsonRpcServer, id: JsonRpcId, params: dict[str, Any] +) -> None: + server.send_response(id=id, result=params) + + +@app.add_request(method="shutdown") +def shutdown( + app: KaiRpcApplication, server: JsonRpcServer, id: JsonRpcId, params: dict[str, Any] +) -> None: + server.shutdown_flag = True + + server.send_response(id=id, result={}) + + +@app.add_request(method="exit") +def exit( + app: KaiRpcApplication, server: JsonRpcServer, id: JsonRpcId, params: dict[str, Any] +) -> None: + server.shutdown_flag = True + + server.send_response(id=id, result={}) + + +# NOTE(shawn-hurley): would it ever make sense to have the server +# "re-initialized" or would you just shutdown and restart the process? 
+@app.add_request(method="initialize")
+def initialize(
+    app: KaiRpcApplication,
+    server: JsonRpcServer,
+    id: JsonRpcId,
+    params: KaiRpcApplicationConfig,
+) -> None:
+    """
+    Handle the "initialize" request: store the config and (re)build logging.
+
+    Responds with a `ServerErrorStart` error if the application is already
+    initialized, and with an `InternalError` carrying the formatted traceback
+    if anything raises while the configuration is being applied. On success
+    the application is marked as initialized and the response carries the
+    dumped configuration.
+
+    The "setConfig" handler clears `app.initialized` and calls this function
+    again, so everything done here has to be safe to repeat on an application
+    that already went through a previous initialization.
+    """
+    if app.initialized:
+        server.send_response(
+            id=id,
+            error=JsonRpcError(
+                code=JsonRpcErrorCode.ServerErrorStart,
+                message="Server already initialized",
+            ),
+        )
+        return
+
+    try:
+        app.config = params
+
+        app.config.root_path = app.config.root_path.resolve()
+        app.config.analyzer_lsp_rpc_path = app.config.analyzer_lsp_rpc_path.resolve()
+        if app.config.log_dir_path:
+            app.config.log_dir_path = app.config.log_dir_path.resolve()
+
+        app.log.setLevel(TRACE)
+
+        # Detach AND close the handlers left over from a previous
+        # initialization. Only clearing the list would drop a FileHandler
+        # without closing it, leaving its log file open after every
+        # reconfiguration.
+        for old_handler in list(app.log.handlers):
+            app.log.removeHandler(old_handler)
+            old_handler.close()
+        app.log.filters.clear()
+
+        # NOTE(review): `stderr_log_level` from the config is not consulted
+        # here; the stderr handler always logs at TRACE. Confirm whether that
+        # is intentional.
+        stderr_handler = logging.StreamHandler(sys.stderr)
+        stderr_handler.setLevel(TRACE)
+        stderr_handler.setFormatter(DEFAULT_FORMATTER)
+        app.log.addHandler(stderr_handler)
+
+        # Log records at `log_level` and above are also handed to the
+        # JSON-RPC logging handler bound to this server.
+        notify_handler = JsonRpcLoggingHandler(server)
+        notify_handler.setLevel(app.config.log_level)
+        notify_handler.setFormatter(DEFAULT_FORMATTER)
+        app.log.addHandler(notify_handler)
+
+        # File logging is opt-in: it needs both a level and a directory.
+        if app.config.file_log_level and app.config.log_dir_path:
+            log_file = app.config.log_dir_path / "kai_rpc.log"
+            log_file.parent.mkdir(parents=True, exist_ok=True)
+
+            file_handler = logging.FileHandler(log_file)
+            file_handler.setLevel(app.config.file_log_level)
+            file_handler.setFormatter(DEFAULT_FORMATTER)
+            app.log.addHandler(file_handler)
+
+        app.log.info(f"Initialized with config: {app.config}")
+
+    except Exception:
+        # Report the failure to the client; `app.initialized` stays False.
+        server.send_response(
+            id=id,
+            error=JsonRpcError(
+                code=JsonRpcErrorCode.InternalError,
+                message=traceback.format_exc(),
+            ),
+        )
+        return
+
+    app.initialized = True
+
+    server.send_response(id=id, result=app.config.model_dump())
+
+
+# NOTE(shawn-hurley): I would just as soon make this another initialize call
+# rather than a separate endpoint. but open to others feedback.
+@app.add_request(method="setConfig") +def set_config( + app: KaiRpcApplication, server: JsonRpcServer, id: JsonRpcId, params: dict[str, Any] +) -> None: + if not app.initialized: + server.send_response(id=id, error=ERROR_NOT_INITIALIZED) + return + app.config = cast(KaiRpcApplicationConfig, app.config) + + # Basically reset everything + app.initialized = False + try: + initialize.func(app, server, id, KaiRpcApplicationConfig.model_validate(params)) + except Exception as e: + server.send_response( + id=id, + error=JsonRpcError( + code=JsonRpcErrorCode.InternalError, + message=str(e), + ), + ) + + +# @app.add_request(method="getRAGSolution") +# def get_rag_solution( +# app: KaiRpcApplication, +# server: JsonRpcServer, +# id: JsonRpcId, +# params: PostGetIncidentSolutionsForFileParams, +# ) -> None: +# if not app.initialized: +# server.send_response(id=id, error=ERROR_NOT_INITIALIZED) +# return +# app.config = cast(KaiRpcApplicationConfig, app.config) + +# # NOTE: This is not at all what we should be doing +# try: +# app.log.info(f"get_rag_solution: {params}") +# params_dict = params.model_dump() +# result = requests.post( +# f"{app.config.kai_backend_url}/get_incident_solutions_for_file", +# json=params_dict, +# timeout=1024, +# ) +# app.log.info(f"get_rag_solution result: {result}") +# app.log.info(f"get_rag_solution result.json(): {result.json()}") + +# server.send_response(id=id, result=dict(result.json())) +# except Exception: +# server.send_response( +# id=id, +# error=JsonRpcError( +# code=JsonRpcErrorCode.InternalError, +# message=str(traceback.format_exc()), +# ), +# ) + + +class GetCodeplanAgentSolutionParams(BaseModel): + file_path: Path + incidents: list[ExtendedIncident] + + +class GitVFSUpdateParams(BaseModel): + work_tree: str # project root + git_dir: str + git_sha: str + diff: str + msg: str + spawning_result: Optional[str] + + children: list["GitVFSUpdateParams"] + + @classmethod + def from_snapshot(cls, snapshot: RepoContextSnapshot) -> 
"GitVFSUpdateParams": + if snapshot.parent: + diff_result = snapshot.diff(snapshot.parent) + diff = diff_result[1] + diff_result[2] + else: + diff = "" + + try: + spawning_result = repr(snapshot.spawning_result) + except Exception: + spawning_result = "" + + return cls( + work_tree=str(snapshot.work_tree), + git_dir=str(snapshot.git_dir), + git_sha=snapshot.git_sha, + diff=diff, + msg=snapshot.msg, + children=[cls.from_snapshot(c) for c in snapshot.children], + spawning_result=spawning_result, + ) + + # spawning_result: Optional[SpawningResult] = None + + +class TestRCMParams(BaseModel): + rcm_root: Path + file_path: Path + new_content: str + + +@app.add_request(method="testRCM") +def test_rcm( + app: KaiRpcApplication, + server: JsonRpcServer, + id: JsonRpcId, + params: TestRCMParams, +) -> None: + rcm = RepoContextManager( + project_root=params.rcm_root, + reflection_agent=ReflectionAgent( + llm=MagicMock(), + ), + ) + + with open(params.file_path, "w") as f: + f.write(params.new_content) + + rcm.commit("testRCM") + + diff = rcm.snapshot.diff(rcm.first_snapshot) + + rcm.reset_to_first() + + server.send_response( + id=id, + result={ + "diff": diff[1] + diff[2], + }, + ) + + +@app.add_request(method="getCodeplanAgentSolution") +def get_codeplan_agent_solution( + app: KaiRpcApplication, + server: JsonRpcServer, + id: JsonRpcId, + params: GetCodeplanAgentSolutionParams, +) -> None: + # create a set of AnalyzerRuleViolations + # seed the task manager with these violations + # get the task with priority 0 and do the whole thingy + + app.log.info(f"get_codeplan_agent_solution: {params}") + if not app.initialized: + server.send_response(id=id, error=ERROR_NOT_INITIALIZED) + return + + app.config = cast(KaiRpcApplicationConfig, app.config) + + try: + model_provider = ModelProvider(app.config.model_provider) + except Exception as e: + server.send_response( + id=id, + error=JsonRpcError( + code=JsonRpcErrorCode.InternalError, + message=str(e), + ), + ) + return + + # Data 
for AnalyzerRuleViolation should probably take an ExtendedIncident + seed_tasks: list[Task] = [] + + for incident in params.incidents: + seed_tasks.append( + AnalyzerRuleViolation( + file=urlparse(incident.uri).path, + line=incident.line_number, + column=-1, # Not contained within report? + message=incident.message, + priority=0, + incident=Incident(**incident.model_dump()), + violation=Violation( + description=incident.violation_description or "", + category=incident.violation_category, + labels=incident.violation_labels, + ), + ruleset=RuleSet( + name=incident.ruleset_name, + description=incident.ruleset_description or "", + ), + ) + ) + + rcm = RepoContextManager( + project_root=app.config.root_path, + reflection_agent=ReflectionAgent( + llm=model_provider.llm, iterations=1, retries=3 + ), + ) + + server.send_notification( + "gitVFSUpdate", + GitVFSUpdateParams.from_snapshot(rcm.first_snapshot).model_dump(), + ) + + task_manager_config = RpcClientConfig( + repo_directory=app.config.root_path, + analyzer_lsp_server_binary=app.config.analyzer_lsp_rpc_path, + rules_directory=app.config.analyzer_lsp_rules_path, + analyzer_lsp_path=app.config.analyzer_lsp_lsp_path, + analyzer_java_bundle_path=app.config.analyzer_lsp_java_bundle_path, + label_selector="konveyor.io/target=quarkus || konveyor.io/target=jakarta-ee", + incident_selector=None, + included_paths=None, + ) + + task_manager = TaskManager( + config=task_manager_config, + rcm=rcm, + seed_tasks=seed_tasks, + validators=[ + MavenCompileStep(task_manager_config), + AnalyzerLSPStep(task_manager_config), + ], + agents=[ + AnalyzerTaskRunner(model_provider.llm), + MavenCompilerTaskRunner(model_provider.llm), + DependencyTaskRunner( + MavenDependencyAgent(model_provider.llm, app.config.root_path) + ), + ], + ) + + flag = False + result: TaskResult + for task in task_manager.get_next_task(0): + if flag: + break + + app.log.debug(f"Executing task {task.__class__.__name__}: {task}") + + result = 
task_manager.execute_task(task) + + app.log.debug(f"Task {task.__class__.__name__} result: {result}") + + task_manager.supply_result(result) + + app.log.debug(f"Executed task {task.__class__.__name__}") + rcm.commit(f"Executed task {task.__class__.__name__}") + + server.send_notification( + "gitVFSUpdate", + GitVFSUpdateParams.from_snapshot(rcm.first_snapshot).model_dump(), + ) + + flag = True + + # FIXME: This is a hack to stop the task_manager as it's hanging trying to stop everything + threading.Thread(target=task_manager.stop).start() + + diff = rcm.snapshot.diff(rcm.first_snapshot) + + rcm.reset_to_first() + + server.send_response( + id=id, + result={ + "encountered_errors": [str(e) for e in result.encountered_errors], + "modified_files": [str(f) for f in result.modified_files], + "diff": diff[1] + diff[2], + }, + ) diff --git a/playpen/.gitignore b/playpen/old_rpc_stuff/.gitignore similarity index 100% rename from playpen/.gitignore rename to playpen/old_rpc_stuff/.gitignore diff --git a/playpen/README.md b/playpen/old_rpc_stuff/README.md similarity index 100% rename from playpen/README.md rename to playpen/old_rpc_stuff/README.md diff --git a/playpen/build.spec b/playpen/old_rpc_stuff/build.spec similarity index 100% rename from playpen/build.spec rename to playpen/old_rpc_stuff/build.spec diff --git a/playpen/client/.gitignore b/playpen/old_rpc_stuff/client/.gitignore similarity index 100% rename from playpen/client/.gitignore rename to playpen/old_rpc_stuff/client/.gitignore diff --git a/playpen/client/README.md b/playpen/old_rpc_stuff/client/README.md similarity index 100% rename from playpen/client/README.md rename to playpen/old_rpc_stuff/client/README.md diff --git a/playpen/old_rpc_stuff/client/__init__.py b/playpen/old_rpc_stuff/client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/playpen/client/anlalyzer_rpc.py b/playpen/old_rpc_stuff/client/anlalyzer_rpc.py similarity index 100% rename from playpen/client/anlalyzer_rpc.py 
rename to playpen/old_rpc_stuff/client/anlalyzer_rpc.py diff --git a/playpen/client/cli.py b/playpen/old_rpc_stuff/client/cli.py similarity index 100% rename from playpen/client/cli.py rename to playpen/old_rpc_stuff/client/cli.py diff --git a/playpen/client/rpc.py b/playpen/old_rpc_stuff/client/rpc.py similarity index 100% rename from playpen/client/rpc.py rename to playpen/old_rpc_stuff/client/rpc.py diff --git a/playpen/client/run_client.sh b/playpen/old_rpc_stuff/client/run_client.sh similarity index 100% rename from playpen/client/run_client.sh rename to playpen/old_rpc_stuff/client/run_client.sh diff --git a/playpen/rpc-client.js b/playpen/old_rpc_stuff/rpc-client.js similarity index 100% rename from playpen/rpc-client.js rename to playpen/old_rpc_stuff/rpc-client.js diff --git a/playpen/rpc-client.py b/playpen/old_rpc_stuff/rpc-client.py similarity index 100% rename from playpen/rpc-client.py rename to playpen/old_rpc_stuff/rpc-client.py diff --git a/playpen/package.json b/playpen/package.json deleted file mode 100644 index 124230a6..00000000 --- a/playpen/package.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "dependencies": { - "vscode-jsonrpc": "^8.2.1" - } -} diff --git a/playpen/repo_level_awareness/agent/api.py b/playpen/repo_level_awareness/agent/api.py index ea583fb3..30fc2401 100644 --- a/playpen/repo_level_awareness/agent/api.py +++ b/playpen/repo_level_awareness/agent/api.py @@ -10,11 +10,15 @@ class AgentRequest: @dataclass class AgentResult: - encountered_errors: list[str] - modified_files: list[Path] + encountered_errors: list[str] | None + modified_files: list[Path] | None class Agent(ABC): @abstractmethod def execute(self, ask: AgentRequest) -> AgentResult: - pass + """ + If the agent cannot handle the request, it should return an AgentResult + with None values. + """ + ... 
diff --git a/playpen/repo_level_awareness/agent/ast_diff/base.py b/playpen/repo_level_awareness/agent/ast_diff/base.py index 4db5c312..d136d42a 100644 --- a/playpen/repo_level_awareness/agent/ast_diff/base.py +++ b/playpen/repo_level_awareness/agent/ast_diff/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod -from collections.abc import Iterator -from typing import Any, Dict, Self +from typing import Any, Self + +from typing_extensions import TypeVar class DiffableSummary(ABC): @@ -9,28 +10,32 @@ class DiffableSummary(ABC): """ @abstractmethod - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Returns a dict representation of info Returns: - Dict[str, Any]: Structured info about this file + dict[str, Any]: Structured info about this file """ pass @abstractmethod - def diff(self, o: Self) -> Dict[str, Any]: + def diff(self, o: Self) -> dict[str, Any]: """Computes diff between current info and another version of same type of info Args: o (Self): Another version of same type of info Returns: - Dict[str, Any]: Structured diff + dict[str, Any]: Structured diff """ pass -class DiffableDict(Dict[str, DiffableSummary], DiffableSummary): +KT = TypeVar("KT") +KV = TypeVar("KV", bound=DiffableSummary) + + +class DiffableDict(dict[KT, KV], DiffableSummary): """A dict that's also a diffable, used to store nested tokens""" def __eq__(self, o: object) -> bool: @@ -43,13 +48,13 @@ def __eq__(self, o: object) -> bool: return True return False - def __iter__(self) -> Iterator[str]: - return iter([v.to_dict() for _, v in self.items()]) + # def __iter__(self) -> Iterator[str]: + # return iter([v.to_dict() for _, v in self.items()]) - def to_dict(self) -> Dict[str, Any]: - return {k: v.to_dict() for k, v in self.items()} + def to_dict(self) -> dict[str, Any]: + return {str(k): v.to_dict() for k, v in self.items()} - def diff(self, o: Self) -> Dict[str, Any]: + def diff(self, o: Self) -> dict[str, Any]: diff = {} added = [o[key].to_dict() for key 
in set(o.keys()) - set(self.keys())] if added: diff --git a/playpen/repo_level_awareness/agent/ast_diff/java.py b/playpen/repo_level_awareness/agent/ast_diff/java.py index 50ea61b3..153a443b 100644 --- a/playpen/repo_level_awareness/agent/ast_diff/java.py +++ b/playpen/repo_level_awareness/agent/ast_diff/java.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Self, Set, Union +from typing import Any, Self import tree_sitter as ts @@ -19,14 +19,14 @@ def equal(self, o: object) -> bool: return self.name == o.name and self.params == o.params return False - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: d = {"name": self.name} if self.params: d["parameters"] = self.params return d - def diff(self, o: Self) -> Dict[str, Any]: - diff = {} + def diff(self, o: Self) -> dict[str, Any]: + diff: dict[str, Any] = {} if self == o: return diff if self.params != o.params: @@ -55,8 +55,8 @@ def __eq__(self, o: object) -> bool: ) return False - def to_dict(self) -> Dict[str, Any]: - d = { + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = { "name": self.name, "type": self.typ, } @@ -64,7 +64,7 @@ def to_dict(self) -> Dict[str, Any]: d["annotations"] = list(self.annotations) return d - def diff(self, o: Self) -> Dict[str, Any]: + def diff(self, o: Self) -> dict[str, Any]: diff = {} if self.typ != o.typ: diff["type"] = { @@ -98,8 +98,8 @@ def __eq__(self, o: object) -> bool: ) return False - def to_dict(self) -> Dict[str, Any]: - d = {"name": self.name} + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = {"name": self.name} if self.parameters: d["parameters"] = self.parameters if self.return_type: @@ -110,8 +110,8 @@ def to_dict(self) -> Dict[str, Any]: d["annotations"] = list(self.annotations) return d - def diff(self, o: Self) -> Dict[str, Any]: - diff = {} + def diff(self, o: Self) -> dict[str, Any]: + diff: dict[str, Any] = {} if self == o: return diff diff["name"] = self.name @@ -148,7 +148,7 
@@ class JClass(DiffableSummary): fields: DiffableDict methods: DiffableDict annotations: DiffableDict - interfaces: Set[str] = field(default_factory=set) + interfaces: set[str] = field(default_factory=set) def __hash__(self) -> int: return hash(self.name) @@ -165,8 +165,8 @@ def __eq__(self, o: object) -> bool: ) return False - def to_dict(self) -> Dict[str, Any]: - d = {"name": self.name} + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = {"name": self.name} if self.super_class: d["super_class"] = self.super_class if self.interfaces: @@ -179,8 +179,8 @@ def to_dict(self) -> Dict[str, Any]: d["methods"] = self.methods.to_dict() return d - def diff(self, o: Self) -> Dict[str, Any]: - diff = {"name": self.name} + def diff(self, o: Self) -> dict[str, Any]: + diff: dict[str, Any] = {"name": self.name} if self.super_class != o.super_class: diff["super_class"] = { "old": self.super_class, @@ -209,18 +209,18 @@ def diff(self, o: Self) -> Dict[str, Any]: @dataclass class JFile(DiffableSummary): classes: DiffableDict - imports: Set[str] = field(default_factory=set) + imports: set[str] = field(default_factory=set) - def to_dict(self) -> Dict[str, Any]: - d = {} + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = {} if self.classes: d["classes"] = self.classes.to_dict() if self.imports: d["imports"] = list(self.imports) return d - def diff(self, o: Self) -> Dict[str, Any]: - diff = {} + def diff(self, o: Self) -> dict[str, Any]: + diff: dict[str, Any] = {} if self.imports != o.imports: diff["imports"] = { "old": list(self.imports), @@ -236,10 +236,10 @@ def diff(self, o: Self) -> Dict[str, Any]: def _extract_java_info(root: ts.Node) -> DiffableSummary: cursor = root.walk() - def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: + def traverse(node: ts.Node) -> DiffableSummary: match node.type: case "modifiers": - annotations = DiffableDict() + annotations = DiffableDict[str, JAnnotation]() for child in node.children: match 
child.type: case "marker_annotation" | "annotation": @@ -248,17 +248,27 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: for annotation_child in child.children: match annotation_child.type: case "identifier": + if annotation_child.text is None: + raise ValueError( + "Annotation identifier has no text" + ) + annotation_name = annotation_child.text.decode( "utf-8" ) case "annotation_argument_list": + if annotation_child.text is None: + raise ValueError( + "Annotation argument list has no text" + ) + params = ( annotation_child.text.decode("utf-8") .replace("\\n", "", -1) .replace("\\t", "", -1) ) annotation = JAnnotation(annotation_name, params) - annotations[hash(annotation)] = annotation + annotations[str(hash(annotation))] = annotation return annotations case "field_declaration": name = "" @@ -267,11 +277,19 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: for field_child in node.children: match field_child.type: case "type_identifier" | "generic_type": + if field_child.text is None: + raise ValueError("Field type has no text") + type = field_child.text.decode("utf-8") + case "variable_declarator": for var_child in field_child.children: + if var_child.text is None: + raise ValueError("Variable declarator has no text") + if var_child.type == "identifier": name = var_child.text.decode("utf-8") + case "modifiers": field_child_info = traverse(field_child) if isinstance(field_child_info, DiffableDict): @@ -285,14 +303,26 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: for mt_child in node.children: match mt_child.type: case "identifier": + if mt_child.text is None: + raise ValueError("Method identifier has no text") + name = mt_child.text.decode("utf-8") + case "modifiers": anns = traverse(mt_child) if isinstance(anns, DiffableDict): annotations = anns + case "formal_parameters": + if mt_child.text is None: + raise ValueError("Method parameters have no text") + params 
= mt_child.text.decode("utf-8") + case "block": + if mt_child.text is None: + raise ValueError("Method block has no text") + body = ( mt_child.text.decode("utf-8") .strip() @@ -306,9 +336,10 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: name = "" fields = DiffableDict[str, JVariable]() methods = DiffableDict[str, JMethod]() - interfaces: Set[str] = set() + interfaces: set[str] = set() super_class: str = "" annotations = DiffableDict[str, JAnnotation]() + for class_child in node.children: match class_child.type: case "modifiers": @@ -316,17 +347,22 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: if isinstance(mods, DiffableDict): annotations = mods case "identifier": + if class_child.text is None: + raise ValueError("Class identifier has no text") + name = class_child.text.decode("utf-8") case "superclass": + if class_child.text is None: + raise ValueError("Superclass has no text") + super_class = class_child.text.decode("utf-8") case "super_interfaces": - interfaces = set( - [ - i.text.decode("utf-8") - for i in class_child.children - if i.type != "," - ] - ) + interfaces = set() + + for i in class_child.children: + if i.type != "," and i.text is not None: + interfaces.add(i.text.decode("utf-8")) + case "class_body": for cb_child in class_child.children: match cb_child.type: @@ -336,11 +372,11 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: case cb_info if isinstance( cb_info, JVariable ): - fields[hash(cb_info)] = cb_info + fields[str(hash(cb_info))] = cb_info case cb_info if isinstance( cb_info, JMethod ): - methods[hash(cb_info)] = cb_info + methods[str(hash(cb_info))] = cb_info return JClass( name=name, fields=fields, @@ -350,11 +386,14 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: super_class=super_class, ) - classes = DiffableDict() + classes = DiffableDict[str, DiffableSummary]() imports = set() for child in 
node.children: match child.type: case "import_declaration": + if child.text is None: + raise ValueError("Import declaration has no text") + imports.add( child.text.decode("utf-8") .replace("import ", "", -1) @@ -366,4 +405,7 @@ def traverse(node: ts.Node) -> Union[DiffableDict, JVariable, JMethod, JClass]: classes[cb_info.name] = cb_info return JFile(classes=classes, imports=imports) + if cursor.node is None: + return DiffableDict() + return traverse(cursor.node) diff --git a/playpen/repo_level_awareness/agent/dependency_agent/api.py b/playpen/repo_level_awareness/agent/dependency_agent/api.py index 44055655..86e5322a 100644 --- a/playpen/repo_level_awareness/agent/dependency_agent/api.py +++ b/playpen/repo_level_awareness/agent/dependency_agent/api.py @@ -1,7 +1,7 @@ # trunk-ignore-begin(ruff/E402) import sys -sys.modules["_elementtree"] = None +sys.modules["_elementtree"] = None # type: ignore[assignment] import json # trunk-ignore(bandit/B405) diff --git a/playpen/repo_level_awareness/agent/dependency_agent/dependency_agent.py b/playpen/repo_level_awareness/agent/dependency_agent/dependency_agent.py index 080570ab..1af63015 100644 --- a/playpen/repo_level_awareness/agent/dependency_agent/dependency_agent.py +++ b/playpen/repo_level_awareness/agent/dependency_agent/dependency_agent.py @@ -4,12 +4,12 @@ # so that we can hook into the parser and add attributes to the element. 
import sys -sys.modules["_elementtree"] = None +sys.modules["_elementtree"] = None # type: ignore[assignment] import logging from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional +from typing import Optional, cast from langchain.prompts.chat import HumanMessagePromptTemplate from langchain_core.language_models.chat_models import BaseChatModel @@ -51,14 +51,14 @@ class _action: @dataclass class MavenDependencyResult(AgentResult): - final_answer: str - fqdn_response: FQDNResponse - find_in_pom: FindInPomResponse + final_answer: Optional[str] + fqdn_response: Optional[FQDNResponse] + find_in_pom: Optional[FindInPomResponse] @dataclass class _llm_response: - actions: List[_action] + actions: list[_action] final_answer: str @@ -102,16 +102,16 @@ class MavenDependencyAgent(Agent): Arguments: - artifact_id: str - The alias name of the symbol to find the definition for. - group_id: str - The path to the file where the alias is used. - - version: Optinal[str] - The line number where the alias is used. + - version: Optional[str] - The line number where the alias is used. Action: ```python - result = search_fqdn.run(artificat_id="commons-collections4", group_id="org.apache.commons", version="4.5.0-M2") + result = search_fqdn.run(artifact_id="commons-collections4", group_id="org.apache.commons", version="4.5.0-M2") print(result) ``` ### Important Notes: 1 We must always use an exact version from the search for Fully Qualified Domain Name of the dependency that we want to update -3. search_fqdn: use this tool to get the fully qualified domain name of a dependecy. This includes the artificatId, groupId and the version. +3. search_fqdn: use this tool to get the fully qualified domain name of a dependency. This includes the artifactId, groupId and the version. Use this tool to get all references to a symbol in the codebase. This will help you understand how the symbol is used in the codebase. 
For example, if you want to know where a function is called, you can use this tool. @@ -120,9 +120,9 @@ class MavenDependencyAgent(Agent): Thought: replace com.google.guava/guava with org.apache.commons/commons-collections4 to the latest version. -Thought: I have the groupId and the artificatId for the collections4 library, but I don't have the latest version. +Thought: I have the groupId and the artifactId for the collections4 library, but I don't have the latest version. Action: ```python -result = search_fqdn.run(artificat_id="commons-collections4", group_id="org.apache.commons") +result = search_fqdn.run(artifact_id="commons-collections4", group_id="org.apache.commons") ``` Observation: We now have the fqdn for the commons-collections4 dependency @@ -134,10 +134,11 @@ class MavenDependencyAgent(Agent): Thought: Now that I have the latest version information and the current start_line and end_line I need to replace the dependency Action: ```python - xml = f"{result.groupjId}{result.artifactId}{result.version}" - result = editor._run(relative_file_path="pom.xml", start_line=start_line, end_line=end_line, patch=xml) - print(result) -Observation: The pom.xml file is now updatedd setting the xml at start line with the new dependency to the end line +xml = f"{result.groupId}{result.artifactId}{result.version}" +result = editor._run(relative_file_path="pom.xml", start_line=start_line, end_line=end_line, patch=xml) +print(result) +``` +Observation: The pom.xml file is now updated setting the xml at start line with the new dependency to the end line Final Answer: Updated the guava to the commons-collections4 dependency @@ -179,7 +180,7 @@ def __init__( self.child_agent = FQDNDependencySelectorAgent(llm=llm) self.agent_methods.update({"find_in_pom._run": find_in_pom(project_base)}) - def execute(self, ask: AgentRequest) -> MavenDependencyResult: + def execute(self, ask: AgentRequest) -> AgentResult: if not isinstance(ask, MavenDependencyRequest): return 
AgentResult(encountered_errors=[], modified_files=[]) @@ -205,10 +206,10 @@ def execute(self, ask: AgentRequest) -> MavenDependencyResult: fix_gen_response = self.__llm.invoke(msg) llm_response = self.parse_llm_response(fix_gen_response.content) # Break out of the while loop, if we don't have a final answer then we need to retry - if not self.__should_continue(llm_response): + if llm_response is None or not llm_response.final_answer: break - # We do not believe that we should not conintue now we have to continue after running the code that is asked to be run. + # We do not believe that we should not continue now we have to continue after running the code that is asked to be run. # The only exception to this rule, is when we actually update the file, that should be handled by the caller. # This happens sometimes that the LLM will stop and wait for more information. @@ -228,8 +229,8 @@ def execute(self, ask: AgentRequest) -> MavenDependencyResult: if not maven_search: for a in llm_response.actions: if "search_fqdn.run" in a.code: - m = self.agent_methods.get("search_fqdn.run") logger.debug("running search for FQDN") + m = self.agent_methods.get("search_fqdn.run", lambda x: None) result = m(a.code) if not result or isinstance(result, list): logger.info("Need to call sub-agent for selecting FQDN") @@ -252,7 +253,7 @@ def execute(self, ask: AgentRequest) -> MavenDependencyResult: if "find_in_pom._run" in a.code: logger.debug("running find in pom") m = self.agent_methods.get("find_in_pom._run") - find_pom_lines = m(a.code) + find_pom_lines = cast(FindInPomResponse, m(a.code)) # We are going to give back the response, The caller should be responsible for running the code generated by the AI. 
# If we have not take the actions step wise, in the LLM, we need to run all but editor here @@ -266,11 +267,11 @@ def execute(self, ask: AgentRequest) -> MavenDependencyResult: ) def parse_llm_response( - self, content: str | List[str] | Dict + self, content: str | list[str] | dict ) -> Optional[_llm_response]: # We should not expect that the value is anything other than str for the type of # call that we know we are making - if isinstance(content, Dict) or isinstance(content, List): + if isinstance(content, dict) or isinstance(content, list): return None actions = [] @@ -347,8 +348,3 @@ def parse_llm_response( else: observation_str = line.strip() return _llm_response(actions, final_answer) - - def __should_continue(self, llm_response: Optional[_llm_response]) -> bool: - if not llm_response or not llm_response.final_answer: - return True - return False diff --git a/playpen/repo_level_awareness/agent/dependency_agent/dependency_fqdn_selection.py b/playpen/repo_level_awareness/agent/dependency_agent/dependency_fqdn_selection.py index db3f8281..57181b09 100644 --- a/playpen/repo_level_awareness/agent/dependency_agent/dependency_fqdn_selection.py +++ b/playpen/repo_level_awareness/agent/dependency_agent/dependency_fqdn_selection.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Optional from jinja2 import Template from langchain_core.language_models.chat_models import BaseChatModel @@ -22,7 +22,7 @@ class FQDNDependencySelectorRequest(AgentRequest): msg: str code: str - query: List[str] + query: list[str] times: int @@ -109,9 +109,9 @@ def execute(self, ask: AgentRequest) -> FQDNDependencySelectorResult: ) def parse_llm_response( - self, content: str | List[str] | Dict + self, content: str | list[str] | dict ) -> Optional[__llm_response]: - if isinstance(content, Dict) or isinstance(content, List): + if isinstance(content, dict) or isinstance(content, list): return None in_reasoning = False 
diff --git a/playpen/repo_level_awareness/agent/dependency_agent/test_dependency_agent.py b/playpen/repo_level_awareness/agent/dependency_agent/test_dependency_agent.py index 6be11d49..b38df37b 100644 --- a/playpen/repo_level_awareness/agent/dependency_agent/test_dependency_agent.py +++ b/playpen/repo_level_awareness/agent/dependency_agent/test_dependency_agent.py @@ -5,7 +5,7 @@ # Forcing the reload after changing and re-importing will allow us # to pass the test. -sys.modules["_elementtree"] = None +sys.modules["_elementtree"] = None # type: ignore[assignment] importlib.reload(ET) import os @@ -153,7 +153,7 @@ def test_search_fqdn(self): @dataclass class TestCase: code: str - expected: FQDNResponse + expected: FQDNResponse | None testCases = [ TestCase( diff --git a/playpen/repo_level_awareness/agent/dependency_agent/util.py b/playpen/repo_level_awareness/agent/dependency_agent/util.py index db9f8286..b28d473b 100644 --- a/playpen/repo_level_awareness/agent/dependency_agent/util.py +++ b/playpen/repo_level_awareness/agent/dependency_agent/util.py @@ -1,12 +1,12 @@ # trunk-ignore-begin(ruff/E402) import sys -sys.modules["_elementtree"] = None +sys.modules["_elementtree"] = None # type: ignore[assignment] import os import xml.etree.ElementTree as ET # trunk-ignore(bandit/B405) from pathlib import Path -from typing import Callable, List, Optional +from typing import Callable, Optional import requests @@ -19,7 +19,7 @@ # trunk-ignore-end(ruff/E402) -def search_fqdn_query(query: str) -> Optional[FQDNResponse] | List[FQDNResponse]: +def search_fqdn_query(query: str) -> Optional[FQDNResponse] | list[FQDNResponse]: resp = requests.get( f"https://search.maven.org/solrsearch/select?q={query}", timeout=10 ) @@ -47,7 +47,7 @@ def search_fqdn_query(query: str) -> Optional[FQDNResponse] | List[FQDNResponse] ) -def search_fqdn(code: str) -> Optional[FQDNResponse] | List[FQDNResponse]: +def search_fqdn(code: str) -> Optional[FQDNResponse] | list[FQDNResponse]: query = 
get_maven_query_from_code(code) return search_fqdn_query(query) diff --git a/playpen/repo_level_awareness/api.py b/playpen/repo_level_awareness/api.py index 93bfca1d..f36d6245 100644 --- a/playpen/repo_level_awareness/api.py +++ b/playpen/repo_level_awareness/api.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from pathlib import Path -from typing import List, Optional +from typing import Optional @dataclass @@ -10,10 +10,10 @@ class RpcClientConfig: analyzer_lsp_server_binary: Path rules_directory: Path analyzer_lsp_path: Path - analyzer_java_bundle: Path + analyzer_java_bundle_path: Path label_selector: Optional[str] incident_selector: Optional[str] - included_paths: Optional[List[str]] + included_paths: Optional[list[str]] @dataclass(eq=False, kw_only=True) @@ -21,7 +21,7 @@ class Task: priority: int = 10 depth: int = 0 parent: Optional["Task"] = None - children: List["Task"] = field(default_factory=list, compare=False) + children: list["Task"] = field(default_factory=list, compare=False) retry_count: int = 0 max_retries: int = 3 creation_order: int = field(init=False) diff --git a/playpen/repo_level_awareness/codeplan.py b/playpen/repo_level_awareness/codeplan.py index ac831f2f..d4ae1733 100755 --- a/playpen/repo_level_awareness/codeplan.py +++ b/playpen/repo_level_awareness/codeplan.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, Set +from typing import Any, Generator, Optional from langchain_core.language_models.chat_models import BaseChatModel from pydantic import BaseModel, ConfigDict @@ -20,7 +20,7 @@ AnalyzerTaskRunner, ) from playpen.repo_level_awareness.task_runner.analyzer_lsp.validator import ( - AnlayzerLSPStep, + AnalyzerLSPStep, ) from playpen.repo_level_awareness.task_runner.api import TaskRunner from playpen.repo_level_awareness.task_runner.compiler.compiler_task_runner import ( @@ -100,7 +100,7 @@ def codeplan(config: 
RpcClientConfig, seed_tasks: list[Task]): config, RepoContextManager(config.repo_directory, llm=llm), seed_tasks, - validators=[MavenCompileStep(config), AnlayzerLSPStep(config)], + validators=[MavenCompileStep(config), AnalyzerLSPStep(config)], agents=[ AnalyzerTaskRunner(modelProvider.llm), MavenCompilerTaskRunner(modelProvider.llm), @@ -126,14 +126,14 @@ def __init__( validators: Optional[list[ValidationStep]] = None, agents: Optional[list[TaskRunner]] = None, ) -> None: - self.validators: List[ValidationStep] = [] + self.validators: list[ValidationStep] = [] self.processed_files: list[Path] = [] self.unprocessed_files: list[Path] = [] - self.processed_tasks: Set[Task] = set() - self.task_stacks: Dict[int, List[Task]] = {} - self.ignored_tasks: List[Task] = [] + self.processed_tasks: set[Task] = set() + self.task_stacks: dict[int, list[Task]] = {} + self.ignored_tasks: list[Task] = [] if validators is not None: self.validators.extend(validators) @@ -164,6 +164,7 @@ def execute_task(self, task: Task) -> TaskResult: agent = self.get_agent_for_task(task) logger.info("Agent selected for task: %s", agent) result = agent.execute_task(self.rcm, task) + logger.debug("Task execution result: %s", result) return result @@ -172,6 +173,7 @@ def get_agent_for_task(self, task: Task) -> TaskRunner: if agent.can_handle_task(task): logger.debug("Agent %s can handle task %s", agent, task) return agent + logger.error("No agent available for task: %s", task) raise Exception("No agent available for this task") @@ -189,7 +191,7 @@ def supply_result(self, result: TaskResult) -> None: def run_validators(self) -> list[Task]: logger.info("Running validators.") - validation_tasks: List[Task] = [] + validation_tasks: list[Task] = [] for validator in self.validators: logger.debug("Running validator: %s", validator) @@ -361,7 +363,7 @@ def handle_ignored_task(self, task: Task): "Task %s exceeded max retries and added to ignored tasks.", task ) - def stop(self): + def stop(self) -> None: 
logger.info("Stopping TaskManager.") for a in self.agents: if hasattr(a, "stop"): @@ -375,5 +377,11 @@ def stop(self): if __name__ == "__main__": - with __import__("ipdb").launch_ipdb_on_exception(): + try: + import ipdb + + with ipdb.launch_ipdb_on_exception(): + main() + + except ImportError: main() diff --git a/playpen/repo_level_awareness/task_runner/analyzer_lsp/api.py b/playpen/repo_level_awareness/task_runner/analyzer_lsp/api.py index e8994ac8..9095535a 100644 --- a/playpen/repo_level_awareness/task_runner/analyzer_lsp/api.py +++ b/playpen/repo_level_awareness/task_runner/analyzer_lsp/api.py @@ -7,6 +7,12 @@ @dataclass(eq=False, kw_only=True) class AnalyzerRuleViolation(ValidationError): incident: Incident + + # NOTE(JonahSussman): Violation contains a list of Incidents, and RuleSet + # contains a list of Violations. We have another class, ExtendedIncident, + # that is a flattened version of this, but it might not contain everything + # we want yet. Maybe there's a better way to create ExtendedIncident. I + # don't think these fields are used anywhere regardless. violation: Violation ruleset: RuleSet # TODO Highest priority? 
diff --git a/playpen/repo_level_awareness/task_runner/analyzer_lsp/task_runner.py b/playpen/repo_level_awareness/task_runner/analyzer_lsp/task_runner.py index a1b116e4..a1f46711 100644 --- a/playpen/repo_level_awareness/task_runner/analyzer_lsp/task_runner.py +++ b/playpen/repo_level_awareness/task_runner/analyzer_lsp/task_runner.py @@ -2,11 +2,13 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Optional, cast from jinja2 import Template from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from pygments import lexers +from pygments.lexer import LexerMeta from pygments.util import ClassNotFound from playpen.repo_level_awareness.api import Task, TaskResult @@ -106,11 +108,11 @@ def __init__(self, llm: BaseChatModel) -> None: def refine_task(self, errors: list[str]) -> None: """We currently do not refine the tasks""" - return super().refine_task(errors) + raise NotImplementedError("We currently do not refine the tasks") def can_handle_error(self, errors: list[str]) -> bool: """We currently do not know if we can handle errors""" - return super().can_handle_error(errors) + raise NotImplementedError("We currently do not know if we can handle errors") def can_handle_task(self, task: Task) -> bool: """Will determine if the task if a MavenCompilerError, and if we can handle these issues.""" @@ -147,7 +149,7 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult: with open(task.file, "w") as f: f.write(resp.java_file) - rcm.commit("compiler", resp) + rcm.commit(f"AnalyzerTaskRunner changed file {str(task.file)}", resp) return TaskResult(modified_files=[Path(file_name)], encountered_errors=[]) return TaskResult(modified_files=[], encountered_errors=[]) @@ -156,7 +158,7 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult: def parse_llm_response(self, message: BaseMessage) -> AnalyzerLLMResponse: 
"""Private method that will be used to parse the contents and get the results""" - lines_of_output = message.content.splitlines() + lines_of_output = cast(str, message.content).splitlines() in_java_file = False in_reasoning = False @@ -192,13 +194,13 @@ def parse_llm_response(self, message: BaseMessage) -> AnalyzerLLMResponse: ) -def guess_language(code: str, filename: str = None) -> str: +def guess_language(code: str, filename: Optional[str] = None) -> str: try: - if filename: + if filename is not None: lexer = lexers.guess_lexer_for_filename(filename, code) logger.debug(f"{filename} classified as {lexer.aliases[0]}") else: - lexer = lexers.guess_lexer(code) + lexer = cast(LexerMeta, lexers.guess_lexer(code)) logger.debug(f"Code content classified as {lexer.aliases[0]}\n{code}") return lexer.aliases[0] except ClassNotFound: diff --git a/playpen/repo_level_awareness/task_runner/analyzer_lsp/validator.py b/playpen/repo_level_awareness/task_runner/analyzer_lsp/validator.py index 16f5a004..add96ae9 100644 --- a/playpen/repo_level_awareness/task_runner/analyzer_lsp/validator.py +++ b/playpen/repo_level_awareness/task_runner/analyzer_lsp/validator.py @@ -1,12 +1,16 @@ import logging import subprocess # trunk-ignore(bandit/B404) -from typing import Dict, List +import threading +from io import BufferedReader, BufferedWriter +from typing import IO, Any, cast from urllib.parse import urlparse +from pydantic import BaseModel + from kai.models.report import Report -from playpen.client import anlalyzer_rpc as analyzer_rpc from playpen.repo_level_awareness.api import ( RpcClientConfig, + ValidationError, ValidationResult, ValidationStep, ) @@ -14,61 +18,93 @@ AnalyzerDependencyRuleViolation, AnalyzerRuleViolation, ) +from playpen.rpc.core import JsonRpcServer +from playpen.rpc.models import JsonRpcError +from playpen.rpc.streams import BareJsonStream logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -class AnlayzerLSPStep(ValidationStep): +def 
log_stderr(stderr: IO[bytes]) -> None: + for line in iter(stderr.readline, b""): + logger.info(line.decode("utf-8")) - rpc: analyzer_rpc.AnalyzerRpcServer - def __init__(self, RpcClientConfig: RpcClientConfig) -> None: +class AnalyzerLSPStep(ValidationStep): + def __init__(self, config: RpcClientConfig) -> None: """This will start and analyzer-lsp jsonrpc server""" # trunk-ignore-begin(bandit/B603) rpc_server = subprocess.Popen( [ - RpcClientConfig.analyzer_lsp_server_binary, + config.analyzer_lsp_server_binary, "-source-directory", - RpcClientConfig.repo_directory, + config.repo_directory, "-rules-directory", - RpcClientConfig.rules_directory, + config.rules_directory, "-lspServerPath", - RpcClientConfig.analyzer_lsp_path, + config.analyzer_lsp_path, "-bundles", - RpcClientConfig.analyzer_java_bundle, + config.analyzer_java_bundle_path, "-log-file", "./kai-analyzer.log", ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, - text=True, + stderr=subprocess.PIPE, ) # trunk-ignore-end(bandit/B603) - self.rpc = analyzer_rpc.AnalyzerRpcServer( - json_rpc_endpoint=analyzer_rpc.AnlayzerRPCEndpoint( - rpc_server.stdin, rpc_server.stdout + self.stderr_logging_thread = threading.Thread( + target=log_stderr, args=(rpc_server.stderr,) + ) + self.stderr_logging_thread.start() + + self.rpc = JsonRpcServer( + json_rpc_stream=BareJsonStream( + cast(BufferedReader, rpc_server.stdout), + cast(BufferedWriter, rpc_server.stdin), ), - timeout=180, + request_timeout=60, ) self.rpc.start() - super().__init__(RpcClientConfig) + super().__init__(config) def run(self) -> ValidationResult: analyzer_output = self.__run_analyzer_lsp() - errors = self.__parse_analyzer_lsp_output(analyzer_output) - return ValidationResult(passed=not errors, errors=errors) - def __run_analyzer_lsp(self) -> List[AnalyzerRuleViolation]: + # TODO: Possibly add messages to the results + ERR = ValidationResult( + passed=False, + errors=[ + ValidationError( + file="", line=-1, column=-1, message="Analyzer LSP failed" + ) 
+ ], + ) + if analyzer_output is None: + return ERR + elif isinstance(analyzer_output, JsonRpcError): + return ERR + elif analyzer_output.result is None: + return ERR + elif isinstance(analyzer_output.result, BaseModel): + analyzer_output.result = analyzer_output.result.model_dump() + + errors = self.__parse_analyzer_lsp_output(analyzer_output.result) + return ValidationResult( + passed=not errors, errors=cast(list[ValidationError], errors) + ) + def __run_analyzer_lsp(self): request_params = { - "label_selector": "konveyor.io/target=quarkus konveyor.io/target=jakarta-ee ", + "label_selector": "konveyor.io/target=quarkus konveyor.io/target=jakarta-ee", "included_paths": [], "incident_selector": "", } + if self.config.label_selector is not None: request_params["label_selector"] = self.config.label_selector @@ -78,14 +114,14 @@ def __run_analyzer_lsp(self) -> List[AnalyzerRuleViolation]: if self.config.incident_selector is not None: request_params["incident_selector"] = self.config.incident_selector - return self.rpc.call_method( + return self.rpc.send_request( "analysis_engine.Analyze", - kwargs=request_params, + params=request_params, ) def __parse_analyzer_lsp_output( - self, analyzer_output: Dict[str, any] - ) -> List[AnalyzerRuleViolation]: + self, analyzer_output: dict[str, Any] + ) -> list[AnalyzerRuleViolation]: rulesets = analyzer_output.get("Rulesets") if not rulesets or not isinstance(rulesets, list): @@ -93,7 +129,7 @@ def __parse_analyzer_lsp_output( r = Report.load_report_from_object(rulesets, "analysis_run_task_runner") - validation_errors: List[AnalyzerRuleViolation] = [] + validation_errors: list[AnalyzerRuleViolation] = [] for _k, v in r.rulesets.items(): for _vk, vio in v.violations.items(): for i in vio.incidents: @@ -104,7 +140,7 @@ def __parse_analyzer_lsp_output( class_to_use( file=urlparse(i.uri).path, line=i.line_number, - column=None, + column=-1, message=i.message, incident=i, violation=vio, @@ -115,4 +151,5 @@ def 
__parse_analyzer_lsp_output( return validation_errors def stop(self): + self.stderr_logging_thread.join() self.rpc.stop() diff --git a/playpen/repo_level_awareness/task_runner/compiler/compiler_task_runner.py b/playpen/repo_level_awareness/task_runner/compiler/compiler_task_runner.py index b3642bfd..c6fd83da 100644 --- a/playpen/repo_level_awareness/task_runner/compiler/compiler_task_runner.py +++ b/playpen/repo_level_awareness/task_runner/compiler/compiler_task_runner.py @@ -1,7 +1,6 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import List from jinja2 import Template from langchain_core.language_models.chat_models import BaseChatModel @@ -28,10 +27,10 @@ class MavenCompilerLLMResponse(SpawningResult): reasoning: str java_file: str - addional_information: str + additional_information: str file_path: str = "" input_file: str = "" - input_errors: List[str] = field(default_factory=list) + input_errors: list[str] = field(default_factory=list) def to_reflection_task(self) -> ReflectionTask: return ReflectionTask( @@ -44,7 +43,7 @@ def to_reflection_task(self) -> ReflectionTask: class MavenCompilerTaskRunner(TaskRunner): - """This agent is reponsible for taking a set of maven compiler issues and solving. + """This agent is responsible for taking a set of maven compiler issues and solving. For a given file it will asking LLM's for the changes that are needed for the at whole file returning the results. @@ -66,7 +65,7 @@ class MavenCompilerTaskRunner(TaskRunner): You must reason through the required changes and rewrite the Java file to make it compile. - You will then provide an step-by-step explaination of the changes required tso that someone could recreate it in a similar situation. + You will then provide an step-by-step explanation of the changes required tso that someone could recreate it in a similar situation. 
""" ) @@ -80,7 +79,7 @@ class MavenCompilerTaskRunner(TaskRunner): {{src_file_contents}} - # Ouput Instructions + # Output Instructions Structure your output in Markdown format such as: ## Updated Java File @@ -89,8 +88,8 @@ class MavenCompilerTaskRunner(TaskRunner): ## Reasoning Write the step by step reasoning in this markdown section. If you are unsure of a step or reasoning, clearly state you are unsure and why. - ## Additional Infomation (optional) - If you have additional details or steps that need to be perfomed, put it here. Say I have completed the changes when you are done explaining the reasoning[/INST] + ## Additional Information (optional) + If you have additional details or steps that need to be performed, put it here. Say I have completed the changes when you are done explaining the reasoning[/INST] """ ) @@ -127,11 +126,11 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult: src_file_contents=src_file_contents, compile_errors=compile_errors ) - aimessage = self.__llm.invoke( + ai_message = self.__llm.invoke( [self.system_message, HumanMessage(content=content)] ) - resp = self.parse_llm_response(aimessage) + resp = self.parse_llm_response(ai_message) resp.file_path = task.file resp.input_file = src_file_contents resp.input_errors = [task.message] @@ -140,7 +139,7 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult: with open(task.file, "w") as f: f.write(resp.java_file) - rcm.commit("compiler", resp) + rcm.commit(f"MavenCompilerTaskRunner changed file {str(task.file)}", resp) return TaskResult(modified_files=[Path(task.file)], encountered_errors=[]) @@ -163,7 +162,7 @@ def parse_llm_response(self, message: BaseMessage) -> MavenCompilerLLMResponse: in_java_file = False in_reasoning = True continue - if line.strip() == "## Additional Infomation (optional)": + if line.strip() == "## Additional Information (optional)": in_reasoning = False in_additional_details = True continue @@ -178,5 +177,5 @@ def 
parse_llm_response(self, message: BaseMessage) -> MavenCompilerLLMResponse: return MavenCompilerLLMResponse( reasoning=reasoning, java_file=java_file, - addional_information=additional_details, + additional_information=additional_details, ) diff --git a/playpen/repo_level_awareness/task_runner/compiler/maven_validator.py b/playpen/repo_level_awareness/task_runner/compiler/maven_validator.py index 30b691b4..d855b9cf 100755 --- a/playpen/repo_level_awareness/task_runner/compiler/maven_validator.py +++ b/playpen/repo_level_awareness/task_runner/compiler/maven_validator.py @@ -4,7 +4,7 @@ import re import subprocess # trunk-ignore(bandit/B404) from dataclasses import dataclass, field -from typing import List, Optional, Type +from typing import Optional, Type from playpen.repo_level_awareness.api import ( ValidationError, @@ -26,7 +26,7 @@ def run(self) -> ValidationResult: @dataclass(eq=False) class MavenCompilerError(ValidationError): - details: List[str] = field(default_factory=list) + details: list[str] = field(default_factory=list) parse_lines: Optional[str] = None priority = 1 @@ -136,15 +136,15 @@ def classify_error(message: str) -> Type[MavenCompilerError]: return OtherError -def parse_maven_output(output: str) -> List[MavenCompilerError]: +def parse_maven_output(output: str) -> list[MavenCompilerError]: """ Parses the Maven output and returns a list of MavenCompilerError instances. 
""" - errors: List[MavenCompilerError] = [] + errors: list[MavenCompilerError] = [] lines = output.splitlines() in_compilation_error_section = False error_pattern = re.compile(r"\[ERROR\] (.+?):\[(\d+),(\d+)\] (.+)") - current_error: Optional[MavenCompilerError] = None + current_error: MavenCompilerError acc = [] for i, line in enumerate(lines): @@ -168,6 +168,7 @@ def parse_maven_output(output: str) -> List[MavenCompilerError]: acc.append(line) error_class = classify_error(match.group(4)) current_error = error_class.from_match(match, []) + # Look ahead for details details = [] j = i + 1 @@ -176,6 +177,7 @@ def parse_maven_output(output: str) -> List[MavenCompilerError]: detail_line = lines[j].replace("[ERROR] ", "", -1).strip() details.append(detail_line) j += 1 + current_error.details.extend(details) # Extract additional information based on error type if isinstance(current_error, SymbolNotFoundError): @@ -208,6 +210,7 @@ def parse_maven_output(output: str) -> List[MavenCompilerError]: current_error.inaccessible_class = current_error.message.split( "cannot access" )[-1].strip() + current_error.parse_lines = "\n".join(acc) errors.append(current_error) acc = [] diff --git a/playpen/repo_level_awareness/task_runner/dependency/task_runner.py b/playpen/repo_level_awareness/task_runner/dependency/task_runner.py index 3c778c82..e8749517 100644 --- a/playpen/repo_level_awareness/task_runner/dependency/task_runner.py +++ b/playpen/repo_level_awareness/task_runner/dependency/task_runner.py @@ -101,7 +101,9 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult: ET.indent(tree, "\t", 0) pretty_xml = ET.tostring(root, encoding="UTF-8", default_namespace="") p.write(pretty_xml.decode("utf-8")) - rcm.commit("dependnecy", maven_dep_response) + rcm.commit( + f"DependencyTaskRunner changed file {str(pom)}", maven_dep_response + ) return TaskResult(modified_files=[Path(pom)], encountered_errors=[]) diff --git 
a/playpen/repo_level_awareness/tests/test_priority_queue.py b/playpen/repo_level_awareness/tests/test_priority_queue.py index 787a8013..6b975503 100644 --- a/playpen/repo_level_awareness/tests/test_priority_queue.py +++ b/playpen/repo_level_awareness/tests/test_priority_queue.py @@ -1,5 +1,4 @@ import unittest -from typing import List # Import classes from your codebase from playpen.repo_level_awareness.api import ( @@ -15,7 +14,7 @@ class MockValidationStep(ValidationStep): def __init__( - self, config: RpcClientConfig, error_sequences: List[List[ValidationError]] + self, config: RpcClientConfig, error_sequences: list[list[ValidationError]] ): super().__init__(config) self.error_sequences = error_sequences diff --git a/playpen/repo_level_awareness/utils/xml.py b/playpen/repo_level_awareness/utils/xml.py index 9b7bb22b..cd4671d2 100644 --- a/playpen/repo_level_awareness/utils/xml.py +++ b/playpen/repo_level_awareness/utils/xml.py @@ -3,7 +3,7 @@ # trunk-ignore-begin(ruff/E402) import sys -sys.modules["_elementtree"] = None +sys.modules["_elementtree"] = None # type: ignore[assignment] import xml.etree.ElementTree as ET # trunk-ignore(bandit/B405) # trunk-ignore-end(ruff/E402) @@ -13,7 +13,7 @@ class LineNumberingParser(ET.XMLParser): def _start(self, *args, **kwargs): # Here we assume the default XML parser which is expat # and copy its element position attributes into output Elements - element = super(self.__class__, self)._start(*args, **kwargs) + element = super()._start(*args, **kwargs) element._start_line_number = self.parser.CurrentLineNumber element._start_column_number = self.parser.CurrentColumnNumber element._start_byte_index = self.parser.CurrentByteIndex diff --git a/playpen/repo_level_awareness/vfs/git_vfs.py b/playpen/repo_level_awareness/vfs/git_vfs.py index 65fc4002..88503ef3 100644 --- a/playpen/repo_level_awareness/vfs/git_vfs.py +++ b/playpen/repo_level_awareness/vfs/git_vfs.py @@ -9,9 +9,6 @@ from enum import StrEnum from pathlib import Path 
from typing import Any, Optional -from unittest.mock import MagicMock - -from langchain_core.language_models.chat_models import BaseChatModel from playpen.repo_level_awareness.agent.api import AgentResult from playpen.repo_level_awareness.agent.reflection_agent import ( @@ -44,7 +41,6 @@ class RepoContextSnapshot: parent: Optional["RepoContextSnapshot"] = None children: list["RepoContextSnapshot"] = field(default_factory=list) - # Narrow down this type, could be task, or errors or what have you spawning_result: Optional[SpawningResult] = None @functools.cached_property @@ -73,6 +69,22 @@ def parent_spawning_results(self) -> list[SpawningResult]: return self.parent.parent_spawning_results + [self.spawning_result] + @functools.cached_property + def lineage(self) -> list["RepoContextSnapshot"]: + """ + Returns the lineage of the current snapshot, starting from the initial + commit. In order from oldest to newest. + """ + lineage: list[RepoContextSnapshot] = [self] + parent = self.parent + while parent is not None: + lineage.append(parent) + parent = parent.parent + + lineage.reverse() + + return lineage + def git(self, args: list[str], popen_kwargs: dict[str, Any] | None = None): """ Execute a git command with the given arguments. Returns a tuple of the @@ -108,11 +120,14 @@ def git(self, args: list[str], popen_kwargs: dict[str, Any] | None = None): return proc.returncode, stdout, stderr @staticmethod - def initialize(work_tree: Path) -> "RepoContextSnapshot": + def initialize(work_tree: Path, msg: str | None = None) -> "RepoContextSnapshot": """ Creates a new git repo in the given work_tree, and returns a GitVFSSnapshot. 
""" + if msg is None: + msg = "Initial commit" + work_tree = work_tree.resolve() kai_dir = work_tree / ".kai" kai_dir.mkdir(exist_ok=True) @@ -133,7 +148,7 @@ def initialize(work_tree: Path) -> "RepoContextSnapshot": with open(git_dir / "info" / "exclude", "a") as f: f.write(f"/{str(kai_dir.name)}\n") - tmp_snapshot = tmp_snapshot.commit("Initial commit") + tmp_snapshot = tmp_snapshot.commit(msg) return RepoContextSnapshot( work_tree=work_tree, @@ -186,17 +201,28 @@ def reset(self) -> tuple[int, str, str]: """ return self.git(["reset", "--hard", self.git_sha]) + def diff(self, other: "RepoContextSnapshot") -> tuple[int, str, str]: + """ + Returns the diff between the current snapshot and another snapshot. + """ + return self.git(["diff", other.git_sha, self.git_sha]) + class RepoContextManager: - def __init__(self, project_root: Path, llm: BaseChatModel): + def __init__( + self, + project_root: Path, + reflection_agent: ReflectionAgent, + initial_msg: str | None = None, + ): self.project_root = project_root - self.snapshot = RepoContextSnapshot.initialize(project_root) - - self.reflection_agent = ReflectionAgent(llm=llm, iterations=1, retries=3) + self.snapshot = RepoContextSnapshot.initialize(project_root, initial_msg) + self.first_snapshot = self.snapshot + self.reflection_agent = reflection_agent def commit( self, msg: str | None = None, spawning_result: SpawningResult | None = None - ): + ) -> bool: """ Commits the current state of the repository and updates the snapshot. Also runs the reflection agent validate the repository state. @@ -214,6 +240,8 @@ def commit( self.snapshot = self.snapshot.commit(msg, new_spawning_result) + return True + def reset(self, snapshot: Optional[RepoContextSnapshot] = None): """ Resets the repository to the given snapshot. If no snapshot is provided, @@ -234,11 +262,25 @@ def reset_to_parent(self): self.reset(self.snapshot.parent) + def reset_to_first(self) -> None: + """ + Resets the repository to the initial commit. 
+ """ + while self.snapshot.parent is not None: + self.reset_to_parent() + + def get_lineage(self) -> list[RepoContextSnapshot]: + """ + Returns the lineage of the current snapshot, starting from the initial + commit. The current snapshot is the first element in the list. + """ + return self.snapshot.lineage + # FIXME: remove this function, only there for the little demo below so the # pseudo code works def union_the_result_and_the_errors(*args, **kwargs): - pass + return args[0] if __name__ == "__main__": @@ -262,7 +304,7 @@ def dfs( args = parser.parse_args() - manager = RepoContextManager(args.project_root, llm=MagicMock()) + manager = RepoContextManager(args.project_root) first_snapshot = manager.snapshot class Command(StrEnum): diff --git a/playpen/rpc/__init__.py b/playpen/rpc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/playpen/rpc/callbacks.py b/playpen/rpc/callbacks.py new file mode 100644 index 00000000..85a42aaa --- /dev/null +++ b/playpen/rpc/callbacks.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import functools +import inspect +import traceback +from typing import TYPE_CHECKING, Any, Callable, Literal + +from pydantic import BaseModel, ConfigDict, validate_call + +from playpen.rpc.models import JsonRpcError, JsonRpcErrorCode, JsonRpcRequest +from playpen.rpc.util import TRACE, get_logger + +if TYPE_CHECKING: + from playpen.rpc.core import JsonRpcApplication, JsonRpcServer + + +log = get_logger("jsonrpc") + +# JsonRpcCallable = Callable[[JsonRpcApplication, JsonRpcServer, JsonRpcId, JsonRpcRequestParams], None] +JsonRpcCallable = Callable[..., None] + + +class JsonRpcCallback: + """ + A JsonRpcCallback is a wrapper around a JsonRpcMethodCallable or + JsonRpcNotifyCallable. It validates the parameters and calls the function. + + We use this class to allow for more flexibility in the parameters that can + be passed to the function. 
+ """ + + def __init__( + self, + func: JsonRpcCallable, + kind: Literal["request", "notify"], + method: str, + ): + self.func = func + self.kind = kind + self.method = method + + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + @functools.wraps(self.func) + def validate_func_args( + *args: Any, **kwargs: Any + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + return args, kwargs + + self.validate_func_args = validate_func_args + + sig = inspect.signature(func) + self.params_model: type[dict[str, Any]] | type[BaseModel] | None = [ + (p.annotation) for p in sig.parameters.values() + ][3] + + def __call__( + self, + app: JsonRpcApplication, + server: JsonRpcServer, + request: JsonRpcRequest, + ) -> None: + try: + log.log(TRACE, f"{self.func.__name__} called with {request}") + log.log( + TRACE, + f"{[(p.annotation) for p in inspect.signature(self.func).parameters.values()]}", + ) + + validated_params: BaseModel | dict[str, Any] | None + + if request.params is None: + if self.params_model is not None: + raise ValueError("Expected params to be present") + else: + validated_params = None + elif isinstance(request.params, dict): + if self.params_model is None: + raise ValueError("Expected params to be None") + elif self.params_model is dict: + validated_params = request.params + elif issubclass(self.params_model, BaseModel): + validated_params = self.params_model.model_validate(request.params) + else: + raise ValueError( + f"params_model should be dict | BaseModel | None, got {self.params_model}" + ) + else: + raise ValueError( + f"Expected params to be a dict or None, got {type(request.params)}" + ) + + log.log(TRACE, f"Validated params: {validated_params}") + self.validate_func_args(app, server, request.id, validated_params) + + log.log(TRACE, f"Calling function: {self.func.__name__}") + self.func(app, server, request.id, validated_params) + except Exception: + server.send_response( + id=request.id, + error=JsonRpcError( + 
code=JsonRpcErrorCode.InternalError, + message=traceback.format_exc(), + ), + ) diff --git a/playpen/rpc/core.py b/playpen/rpc/core.py new file mode 100644 index 00000000..f1735d61 --- /dev/null +++ b/playpen/rpc/core.py @@ -0,0 +1,325 @@ +import threading +from typing import Any, Callable, Literal, Optional, overload + +from pydantic import BaseModel + +from playpen.rpc.callbacks import JsonRpcCallable, JsonRpcCallback +from playpen.rpc.models import ( + JsonRpcError, + JsonRpcErrorCode, + JsonRpcId, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResult, +) +from playpen.rpc.streams import JsonRpcStream +from playpen.rpc.util import TRACE, get_logger + +log = get_logger("jsonrpc") + + +class JsonRpcApplication: + """ + Taking a page out of the ASGI standards, JsonRpcApplication is a collection + of JsonRpcCallbacks that can be used to handle incoming requests and + notifications. + """ + + def __init__( + self, + request_callbacks: Optional[dict[str, JsonRpcCallback]] = None, + notify_callbacks: Optional[dict[str, JsonRpcCallback]] = None, + ): + if request_callbacks is None: + request_callbacks = {} + if notify_callbacks is None: + notify_callbacks = {} + + self.request_callbacks = request_callbacks + self.notify_callbacks = notify_callbacks + + def handle_request(self, request: JsonRpcRequest, server: "JsonRpcServer") -> None: + log.log(TRACE, "Handling request: %s", request) + + if request.id is not None: + log.log(TRACE, "Request is a request") + + if request.method not in self.request_callbacks: + server.send_response( + error=JsonRpcError( + code=JsonRpcErrorCode.MethodNotFound, + message=f"Method not found: {request.method}", + ), + id=request.id, + ) + return + + log.log(TRACE, "Calling method: %s", request.method) + + self.request_callbacks[request.method]( + request=request, server=server, app=self + ) + + else: + log.log(TRACE, "Request is a notification") + + if request.method not in self.notify_callbacks: + log.error(f"Notify method not found: 
{request.method}") + return + + log.log(TRACE, "Calling method: %s", request.method) + + self.notify_callbacks[request.method]( + request=request, server=server, app=self + ) + + @overload + def add( + self, + func: JsonRpcCallable, + *, + kind: Literal["request", "notify"] = ..., + method: str | None = ..., + ) -> JsonRpcCallback: ... + + @overload + def add( + self, + func: None = ..., + *, + kind: Literal["request", "notify"] = ..., + method: str | None = ..., + ) -> Callable[[JsonRpcCallable], JsonRpcCallback]: ... + + def add( + self, + func: JsonRpcCallable | None = None, + *, + kind: Literal["request", "notify"] = "request", + method: str | None = None, + ) -> JsonRpcCallback | Callable[[JsonRpcCallable], JsonRpcCallback]: + if method is None: + raise ValueError("Method name must be provided") + + if kind == "request": + callbacks = self.request_callbacks + else: + callbacks = self.notify_callbacks + + def decorator( + func: JsonRpcCallable, + ) -> JsonRpcCallback: + callback = JsonRpcCallback( + func=func, + kind=kind, + method=method, + ) + callbacks[method] = callback + + return callback + + if func: + return decorator(func) + else: + return decorator + + @overload + def add_notify( + self, + func: JsonRpcCallable, + *, + method: str | None = ..., + ) -> JsonRpcCallback: ... + + @overload + def add_notify( + self, + func: None = ..., + *, + method: str | None = ..., + ) -> Callable[[JsonRpcCallable], JsonRpcCallback]: ... + + def add_notify( + self, + func: JsonRpcCallable | None = None, + *, + method: str | None = None, + ) -> JsonRpcCallback | Callable[[JsonRpcCallable], JsonRpcCallback]: + return self.add( + func=func, + kind="notify", + method=method, + ) + + @overload + def add_request( + self, + func: JsonRpcCallable, + *, + method: str | None = ..., + ) -> JsonRpcCallback: ... + + @overload + def add_request( + self, + func: None = ..., + *, + method: str | None = ..., + ) -> Callable[[JsonRpcCallable], JsonRpcCallback]: ... 
+ + def add_request( + self, + func: JsonRpcCallable | None = None, + *, + method: str | None = None, + ) -> JsonRpcCallback | Callable[[JsonRpcCallable], JsonRpcCallback]: + return self.add( + func=func, + kind="request", + method=method, + ) + + def generate_docs(self) -> str: + raise NotImplementedError() + + +class JsonRpcServer(threading.Thread): + """ + Taking a page from Python's ASGI standards, JsonRpcServer serves a + JsonRpcApplication. It is a thread that listens for incoming requests and + notifications on a JsonRpcStream, and sends responses over the same stream. + + We separate the two classes to allow one to define routes in different + files. + + Despite being called "server", you can also use this as a client. + """ + + def __init__( + self, + json_rpc_stream: JsonRpcStream, + app: JsonRpcApplication | None = None, + request_timeout: float | None = 60.0, + ): + if app is None: + app = JsonRpcApplication() + + threading.Thread.__init__(self) + + self.jsonrpc_stream = json_rpc_stream + self.app = app + + self.event_dict: dict[JsonRpcId, threading.Condition] = {} + self.response_dict: dict[JsonRpcId, JsonRpcResponse] = {} + self.next_id = 0 + self.request_timeout = request_timeout + self.outstanding_requests: set[JsonRpcId] = set() + + self.shutdown_flag = False + + def stop(self) -> None: + self.shutdown_flag = True + + def run(self) -> None: + log.debug("Server thread started") + + while not self.shutdown_flag: + msg = self.jsonrpc_stream.recv() + if msg is None: + log.info("Server quit") + break + + elif isinstance(msg, JsonRpcError): + self.jsonrpc_stream.send(JsonRpcResponse(error=msg)) + continue + + elif isinstance(msg, JsonRpcRequest): + log.log(TRACE, "Received request: %s", msg) + if msg.id is not None: + self.outstanding_requests.add(msg.id) + + self.app.handle_request(msg, self) + + if msg.id is not None and msg.id in self.outstanding_requests: + self.send_response( + id=msg.id, + error=JsonRpcError( + code=JsonRpcErrorCode.InternalError, 
+ message="No response sent", + ), + ) + + elif isinstance(msg, JsonRpcResponse): + self.response_dict[msg.id] = msg + cond = self.event_dict[msg.id] + cond.acquire() + cond.notify() + cond.release() + + else: + log.error(f"Unknown message type: {type(msg)}") + + self.jsonrpc_stream.close() + + def send_request( + self, method: str, params: BaseModel | dict[str, Any] | None + ) -> JsonRpcResponse | JsonRpcError | None: + if isinstance(params, BaseModel): + params = params.model_dump() + + log.log(TRACE, "Sending request: %s", method) + current_id = self.next_id + self.next_id += 1 + cond = threading.Condition() + self.event_dict[current_id] = cond + + cond.acquire() + self.jsonrpc_stream.send( + JsonRpcRequest(method=method, params=params, id=current_id) + ) + + if self.shutdown_flag: + cond.release() + return None + + if not cond.wait(self.request_timeout): + cond.release() + return JsonRpcError( + code=JsonRpcErrorCode.InternalError, + message="Timeout waiting for response", + ) + cond.release() + + self.event_dict.pop(current_id) + return self.response_dict.pop(current_id) + + def send_notification( + self, + method: str, + params: dict[str, Any] | None, + ) -> None: + if isinstance(params, BaseModel): + params = params.model_dump() + + self.jsonrpc_stream.send(JsonRpcRequest(method=method, params=params)) + + def send_response( + self, + *, + response: Optional[JsonRpcResponse] = None, + result: Optional[JsonRpcResult] = None, + error: Optional[JsonRpcError] = None, + id: JsonRpcId = None, + ) -> None: + if response is None: + response = JsonRpcResponse(result=result, error=error, id=id) + + if response.id is not None: + if response.id not in self.outstanding_requests: + log.error( + f"Request ID {response.id} not found in outstanding requests\nTried sending: {response}" + ) + return + self.outstanding_requests.remove(response.id) + + self.jsonrpc_stream.send(response) diff --git a/playpen/rpc/logs.py b/playpen/rpc/logs.py new file mode 100644 index 
00000000..8e7e2f61 --- /dev/null +++ b/playpen/rpc/logs.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import logging +import sys +from typing import TYPE_CHECKING + +from playpen.rpc.util import get_logger + +if TYPE_CHECKING: + from playpen.rpc.core import JsonRpcServer + +log = get_logger("jsonrpc") + + +class JsonRpcLoggingHandler(logging.Handler): + def __init__(self, server: JsonRpcServer, method: str = "logMessage"): + logging.Handler.__init__(self) + self.server = server + self.method = method + + def emit(self, record: logging.LogRecord) -> None: + try: + params = { + "name": record.name, + "levelno": record.levelno, + "levelname": record.levelname, + "pathname": record.pathname, + "filename": record.filename, + "module": record.module, + "lineno": record.lineno, + "funcName": record.funcName, + "created": record.created, + "asctime": record.asctime, + "msecs": record.msecs, + "relativeCreated": record.relativeCreated, + "thread": record.thread, + "threadName": record.threadName, + "process": record.process, + "message": record.getMessage(), + } + + self.server.send_notification( + method=self.method, + params=params, + ) + except Exception: + print("Failed to log message", file=sys.stderr) + self.handleError(record) diff --git a/playpen/rpc/models.py b/playpen/rpc/models.py new file mode 100644 index 00000000..ce10f614 --- /dev/null +++ b/playpen/rpc/models.py @@ -0,0 +1,46 @@ +from enum import IntEnum +from typing import Any, Optional + +from pydantic import BaseModel + + +class JsonRpcErrorCode(IntEnum): + ParseError = -32700 + InvalidRequest = -32600 + MethodNotFound = -32601 + InvalidParams = -32602 + InternalError = -32603 + ServerErrorStart = -32099 + ServerErrorEnd = -32000 + ServerNotInitialized = -32002 + UnknownErrorCode = -32001 + + +JsonRpcId = Optional[str | int] + +# NOTE: dict must come before BaseModel in the union, otherwise Pydantic breaks +# when calling model_validate on JsonRpcResponse. 
+JsonRpcResult = dict | BaseModel # type: ignore[type-arg] + + +class JsonRpcError(BaseModel): + code: JsonRpcErrorCode | int + message: str + data: Optional[Any] = None + + +class JsonRpcResponse(BaseModel): + jsonrpc: str = "2.0" + result: Optional[JsonRpcResult] = None + error: Optional[JsonRpcError | str] = None + id: JsonRpcId = None + + +class JsonRpcRequest(BaseModel): + jsonrpc: str = "2.0" + method: str + params: Optional[dict[str, Any]] = None + id: JsonRpcId = None + + +# JsonRpcRequestResult = tuple[JsonRpcResult | None, JsonRpcError | None] diff --git a/playpen/rpc/streams.py b/playpen/rpc/streams.py new file mode 100644 index 00000000..fd5dbb40 --- /dev/null +++ b/playpen/rpc/streams.py @@ -0,0 +1,227 @@ +import json +import threading +from abc import ABC, abstractmethod +from io import BufferedReader, BufferedWriter +from typing import Any, Optional + +from pydantic import BaseModel + +from playpen.rpc.models import ( + JsonRpcError, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, +) +from playpen.rpc.util import TRACE, get_logger + +log = get_logger("jsonrpc") + + +class JsonRpcStream(ABC): + """ + And abstract base class for a JSON-RPC stream. This class is used to + communicate with the JSON-RPC server and client. + """ + + def __init__( + self, + recv_file: BufferedReader, + send_file: BufferedWriter, + json_dumps_kwargs: Optional[dict[Any, Any]] = None, + json_loads_kwargs: Optional[dict[Any, Any]] = None, + ): + if json_dumps_kwargs is None: + json_dumps_kwargs = {} + if json_loads_kwargs is None: + json_loads_kwargs = {} + + self.recv_file = recv_file + self.recv_lock = threading.Lock() + + self.send_file = send_file + self.send_lock = threading.Lock() + + self.json_dumps_kwargs = json_dumps_kwargs + self.json_loads_kwargs = json_loads_kwargs + + def close(self) -> None: + self.recv_file.close() + self.send_file.close() + + @abstractmethod + def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None: ... 
+ + @abstractmethod + def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: ... + + +class LspStyleStream(JsonRpcStream): + """ + Standard LSP-style stream for JSON-RPC communication. This uses HTTP-style + headers for content length and content type. + """ + + JSON_RPC_REQ_FORMAT = "Content-Length: {json_string_len}\r\n\r\n{json_string}" + LEN_HEADER = "Content-Length: " + TYPE_HEADER = "Content-Type: " + + def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None: + json_str = msg.model_dump_json() + json_req = f"Content-Length: {len(json_str)}\r\n\r\n{json_str}" + + # Prevent infinite recursion + if isinstance(msg, JsonRpcRequest) and msg.method != "logMessage": + log.log(TRACE, "Sending request: %s", json_req) + + with self.send_lock: + self.send_file.write(json_req.encode()) + self.send_file.flush() + + def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: + log.debug("Waiting for message") + + with self.recv_lock: + log.log(TRACE, "Reading headers") + content_length = -1 + + while True: + if self.recv_file.closed: + return None + + log.log(TRACE, "Reading header line") + line_bytes = self.recv_file.readline() + + log.log(TRACE, "Read header line: %s", line_bytes) + if not line_bytes: + return None + + line = line_bytes.decode("utf-8") + if not line.endswith("\r\n"): + return JsonRpcError( + code=JsonRpcErrorCode.ParseError, + message="Bad header: missing newline", + ) + + line = line[:-2] + + if line == "": + break + elif line.startswith(self.LEN_HEADER): + line = line[len(self.LEN_HEADER) :] + if not line.isdigit(): + return JsonRpcError( + code=JsonRpcErrorCode.ParseError, + message="Bad header: size is not int", + ) + content_length = int(line) + elif line.startswith(self.TYPE_HEADER): + pass + else: + return JsonRpcError( + code=JsonRpcErrorCode.ParseError, + message=f"Bad header: unknown header {line}", + ) + + if content_length < 0: + return JsonRpcError( + code=JsonRpcErrorCode.ParseError, + message="Bad 
header: missing Content-Length", + ) + + log.log(TRACE, "Got message with content length: %s", content_length) + + try: + msg_str = self.recv_file.read(content_length).decode("utf-8") + msg_dict = json.loads(msg_str, **self.json_loads_kwargs) + except Exception as e: + return JsonRpcError( + code=JsonRpcErrorCode.ParseError, + message=f"Invalid JSON: {e}", + ) + + log.log(TRACE, "Got message: %s", msg_dict) + + try: + if "method" in msg_dict: + return JsonRpcRequest.model_validate(msg_dict) + else: + return JsonRpcResponse.model_validate(msg_dict) + except Exception as e: + return JsonRpcError( + code=JsonRpcErrorCode.ParseError, + message=f"Could not validate JSON: {e}", + ) + + +class BareJsonStream(JsonRpcStream): + def __init__( + self, + recv_file: BufferedReader, + send_file: BufferedWriter, + json_dumps_kwargs: dict[Any, Any] | None = None, + json_loads_kwargs: dict[Any, Any] | None = None, + ): + super().__init__(recv_file, send_file, json_dumps_kwargs, json_loads_kwargs) + + self.buffer: str = "" + self.decoder = json.JSONDecoder() + self.chunk_size = 512 + + def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None: + json_req = msg.model_dump_json() + + # Prevent infinite recursion + if not isinstance(msg, JsonRpcRequest) or msg.method != "logMessage": + log.log(TRACE, "send: %s", json_req) + else: + log_msg = msg.model_copy() + if log_msg.params is None: + log_msg.params = {} + elif isinstance(log_msg.params, dict): + if "message" in log_msg.params: + log_msg.params["message"] = "" + elif isinstance(log_msg.params, BaseModel): + if hasattr(log_msg.params, "message"): + log_msg.params.message = "" + + log.log(TRACE, f"send: {log_msg.model_dump_json()}") + + with self.send_lock: + self.send_file.write(json_req.encode()) + self.send_file.flush() + + def get_from_buffer(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: + try: + msg, idx = self.decoder.raw_decode(self.buffer) + self.buffer = self.buffer[idx:] + + log.log(TRACE, "recv msg: 
%s", msg) + log.log(TRACE, "recv buffer: %s", self.buffer) + + if "method" in msg: + return JsonRpcRequest.model_validate(msg) + else: + return JsonRpcResponse.model_validate(msg) + except json.JSONDecodeError: + return None + except Exception as e: + return JsonRpcError( + code=JsonRpcErrorCode.ParseError, + message=f"Invalid JSON: {e}", + ) + + def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: + with self.recv_lock: + result = self.get_from_buffer() + if result is not None: + return result + + while chunk := self.recv_file.read1(self.chunk_size): + self.buffer += chunk.decode("utf-8") + log.log(TRACE, "recv buffer: %s", self.buffer) + + result = self.get_from_buffer() + if result is not None: + return result + + return None diff --git a/playpen/rpc/util.py b/playpen/rpc/util.py new file mode 100644 index 00000000..5418b968 --- /dev/null +++ b/playpen/rpc/util.py @@ -0,0 +1,71 @@ +import logging +import sys +from typing import Any + +from pydantic import AliasChoices, AliasGenerator, BaseModel, ConfigDict +from pydantic.alias_generators import to_camel + +TRACE = logging.DEBUG - 5 +DEFAULT_FORMATTER = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + + +def get_logger( + name: str, + stderr_level: int | str = "TRACE", + formatter: logging.Formatter = DEFAULT_FORMATTER, +) -> logging.Logger: + logging.addLevelName(logging.DEBUG - 5, "TRACE") + + logger = logging.getLogger(name) + + if logger.hasHandlers(): + return logger + + logger.setLevel(TRACE) + + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(stderr_level) + stderr_handler.setFormatter(formatter) + logger.addHandler(stderr_handler) + + return logger + + +def log_record_to_dict(record: logging.LogRecord) -> dict[str, Any]: + return { + "name": record.name, + "levelno": record.levelno, + "levelname": record.levelname, + "pathname": record.pathname, + "filename": record.filename, + "module": record.module, + "lineno": 
record.lineno, + "funcName": record.funcName, + "created": record.created, + "asctime": record.asctime, + "msecs": record.msecs, + "relativeCreated": record.relativeCreated, + "thread": record.thread, + "threadName": record.threadName, + "process": record.process, + "msg": record.msg, + "args": record.args, + "message": record.getMessage(), + } + + +class CamelCaseBaseModel(BaseModel): + model_config = ConfigDict( + alias_generator=AliasGenerator( + validation_alias=lambda field_name: AliasChoices( + field_name, + to_camel(field_name), + ), + serialization_alias=to_camel, + ), + ) + + def model_dump(self, **kwargs: Any) -> dict[str, Any]: + return super().model_dump(by_alias=True, **kwargs) diff --git a/playpen/rpc_server/rpc.py b/playpen/rpc_server/rpc.py deleted file mode 100644 index ea134ac3..00000000 --- a/playpen/rpc_server/rpc.py +++ /dev/null @@ -1,784 +0,0 @@ -import functools -import json -import logging -import sys -import threading -from abc import ABC, abstractmethod -from enum import IntEnum -from io import BufferedReader, BufferedWriter -from typing import Any, Callable, Literal, Optional, cast, overload - -from pydantic import BaseModel, ConfigDict, validate_call - -TRACE = logging.DEBUG - 5 -DEFAULT_FORMATTER = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) - - -def get_logger( - name: str, - stderr_level: int | str = "TRACE", - formatter: logging.Formatter = DEFAULT_FORMATTER, -) -> logging.Logger: - logging.addLevelName(logging.DEBUG - 5, "TRACE") - - logger = logging.getLogger(name) - - if logger.hasHandlers(): - return logger - - logger.setLevel(TRACE) - - stderr_handler = logging.StreamHandler(sys.stderr) - stderr_handler.setLevel(stderr_level) - stderr_handler.setFormatter(formatter) - logger.addHandler(stderr_handler) - - return logger - - -def log_record_to_dict(record: logging.LogRecord) -> dict[str, Any]: - return { - "name": record.name, - "levelno": record.levelno, - "levelname": record.levelname, - 
"pathname": record.pathname, - "filename": record.filename, - "module": record.module, - "lineno": record.lineno, - "funcName": record.funcName, - "created": record.created, - "asctime": record.asctime, - "msecs": record.msecs, - "relativeCreated": record.relativeCreated, - "thread": record.thread, - "threadName": record.threadName, - "process": record.process, - "msg": record.msg, - "args": record.args, - "message": record.getMessage(), - } - - -log = get_logger("jsonrpc") - - -class JsonRpcErrorCode(IntEnum): - ParseError = -32700 - InvalidRequest = -32600 - MethodNotFound = -32601 - InvalidParams = -32602 - InternalError = -32603 - ServerErrorStart = -32099 - ServerErrorEnd = -32000 - ServerNotInitialized = -32002 - UnknownErrorCode = -32001 - - -JsonRpcId = Optional[str | int] - -JsonRpcResult = BaseModel | dict - - -class JsonRpcError(BaseModel): - code: JsonRpcErrorCode | int - message: str - data: Optional[Any] = None - - -class JsonRpcResponse(BaseModel): - jsonrpc: str = "2.0" - result: Optional[JsonRpcResult] = None - error: Optional[JsonRpcError] = None - id: JsonRpcId = None - - -class JsonRpcRequest(BaseModel): - jsonrpc: str = "2.0" - method: str - params: Optional[dict] = None - id: JsonRpcId = None - - -class JsonRpcStream(ABC): - """ - And abstract base class for a JSON-RPC stream. This class is used to - communicate with the JSON-RPC server and client. 
- """ - - def __init__( - self, - recv_file: BufferedReader, - send_file: BufferedWriter, - json_dumps_kwargs: Optional[dict] = None, - json_loads_kwargs: Optional[dict] = None, - ): - if json_dumps_kwargs is None: - json_dumps_kwargs = {} - if json_loads_kwargs is None: - json_loads_kwargs = {} - - self.recv_file = recv_file - self.recv_lock = threading.Lock() - - self.send_file = send_file - self.send_lock = threading.Lock() - - self.json_dumps_kwargs = json_dumps_kwargs - self.json_loads_kwargs = json_loads_kwargs - - def close(self) -> None: - self.recv_file.close() - self.send_file.close() - - @abstractmethod - def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None: ... - - @abstractmethod - def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: ... - - -class LspStyleStream(JsonRpcStream): - """ - Standard LSP-style stream for JSON-RPC communication. This uses HTTP-style - headers for content length and content type. - """ - - JSON_RPC_REQ_FORMAT = "Content-Length: {json_string_len}\r\n\r\n{json_string}" - LEN_HEADER = "Content-Length: " - TYPE_HEADER = "Content-Type: " - - def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None: - json_str = msg.model_dump_json() - json_req = f"Content-Length: {len(json_str)}\r\n\r\n{json_str}" - - # Prevent infinite recursion - if isinstance(msg, JsonRpcRequest) and msg.method != "logMessage": - log.log(TRACE, "Sending request: %s", json_req) - - with self.send_lock: - self.send_file.write(json_req.encode()) - self.send_file.flush() - - def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: - log.debug("Waiting for message") - - with self.recv_lock: - log.log(TRACE, "Reading headers") - content_length = -1 - - while True: - if self.recv_file.closed: - return None - - log.log(TRACE, "Reading header line") - line_bytes = self.recv_file.readline() - - log.log(TRACE, "Read header line: %s", line_bytes) - if not line_bytes: - return None - - line = line_bytes.decode("utf-8") - 
if not line.endswith("\r\n"): - return JsonRpcError( - code=JsonRpcErrorCode.ParseError, - message="Bad header: missing newline", - ) - - line = line[:-2] - - if line == "": - break - elif line.startswith(self.LEN_HEADER): - line = line[len(self.LEN_HEADER) :] - if not line.isdigit(): - return JsonRpcError( - code=JsonRpcErrorCode.ParseError, - message="Bad header: size is not int", - ) - content_length = int(line) - elif line.startswith(self.TYPE_HEADER): - pass - else: - return JsonRpcError( - code=JsonRpcErrorCode.ParseError, - message=f"Bad header: unknown header {line}", - ) - - if content_length < 0: - return JsonRpcError( - code=JsonRpcErrorCode.ParseError, - message="Bad header: missing Content-Length", - ) - - log.log(TRACE, "Got message with content length: %s", content_length) - - try: - msg_str = self.recv_file.read(content_length).decode("utf-8") - msg_dict = json.loads(msg_str, **self.json_loads_kwargs) - except Exception as e: - return JsonRpcError( - code=JsonRpcErrorCode.ParseError, - message=f"Invalid JSON: {e}", - ) - - log.log(TRACE, "Got message: %s", msg_dict) - - try: - if "method" in msg_dict: - return JsonRpcRequest.model_validate(msg_dict) - else: - return JsonRpcResponse.model_validate(msg_dict) - except Exception as e: - return JsonRpcError( - code=JsonRpcErrorCode.ParseError, - message=f"Could not validate JSON: {e}", - ) - - -class BareJsonStream(JsonRpcStream): - def __init__( - self, - recv_file: BufferedReader, - send_file: BufferedWriter, - json_dumps_kwargs: dict | None = None, - json_loads_kwargs: dict | None = None, - ): - super().__init__(recv_file, send_file, json_dumps_kwargs, json_loads_kwargs) - - self.buffer: str = "" - self.decoder = json.JSONDecoder() - self.chunk_size = 512 - - def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None: - json_req = msg.model_dump_json() - - # Prevent infinite recursion - if not isinstance(msg, JsonRpcRequest) or msg.method != "logMessage": - log.log(TRACE, "send: %s", json_req) - 
else: - log_msg = msg.model_copy() - if log_msg.params is None: - log_msg.params = {} - if "message" in log_msg.params: - log_msg.params["message"] = "" - log.log(TRACE, f"send: {log_msg.model_dump_json()}") - - with self.send_lock: - self.send_file.write(json_req.encode()) - self.send_file.flush() - - def get_from_buffer(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: - try: - msg, idx = self.decoder.raw_decode(self.buffer) - self.buffer = self.buffer[idx:] - - log.log(TRACE, "recv msg: %s", msg) - log.log(TRACE, "recv buffer: %s", self.buffer) - - if "method" in msg: - return JsonRpcRequest.model_validate(msg) - else: - return JsonRpcResponse.model_validate(msg) - except json.JSONDecodeError: - return None - except Exception as e: - return JsonRpcError( - code=JsonRpcErrorCode.ParseError, - message=f"Invalid JSON: {e}", - ) - - def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: - with self.recv_lock: - result = self.get_from_buffer() - if result is not None: - return result - - while chunk := self.recv_file.read1(self.chunk_size): - self.buffer += chunk.decode("utf-8") - log.log(TRACE, "recv buffer: %s", self.buffer) - - result = self.get_from_buffer() - if result is not None: - return result - - return None - - -JsonRpcRequestResult = tuple[JsonRpcResult | None, JsonRpcError | None] - -JsonRpcRequestCallable = Callable[..., JsonRpcRequestResult] -JsonRpcNotifyCallable = Callable[..., None] - - -class JsonRpcCallback: - """ - A JsonRpcCallback is a wrapper around a JsonRpcMethodCallable or - JsonRpcNotifyCallable. It validates the parameters and calls the function. - - We use this class to allow for more flexibility in the parameters that can - be passed to the function. 
- """ - - def __init__( - self, - func: JsonRpcRequestCallable | JsonRpcNotifyCallable, - include_server: bool, - include_app: bool, - kind: Literal["request", "notify"], - method: str, - params_model: type[JsonRpcResult] | None = None, - ): - """ - If params_model is not supplied, the schema will be generated from the - function arguments. - """ - - self.func = func - self.params_model = params_model - self.include_server = include_server - self.include_app = include_app - self.kind = kind - self.method = method - - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) - @functools.wraps(self.func) - def validate_func_args( - *args: Any, **kwargs: Any - ) -> tuple[tuple[Any, ...], dict[str, Any]]: - return args, kwargs - - self.validate_func_args = validate_func_args - - # TODO: Generate docs - # See https://github.com/konveyor/kai/blob/d7727a8113f185ee393d368a924da7c9deedd45b/playpen/rpc_server/rpc.py#L354 - - def __call__( - self, - params: dict | None, - server: Optional["JsonRpcServer"], - app: Optional["JsonRpcApplication"], - ) -> tuple[JsonRpcResult | None, JsonRpcError | None] | None: - if params is None: - params = {} - - kwargs = {} - - if self.params_model is None: - kwargs = params.copy() - elif self.params_model is dict: - kwargs["params"] = params - else: - try: - kwargs["params"] = cast( - type[BaseModel], self.params_model - ).model_validate(params) - except Exception as e: - return None, JsonRpcError( - code=JsonRpcErrorCode.InvalidParams, - message=f"Invalid parameters: {e}", - ) - - if self.include_server: - kwargs["server"] = server - if self.include_app: - kwargs["app"] = app - - try: - self.validate_func_args(**kwargs) - return self.func(**kwargs) - except Exception as e: - return None, JsonRpcError( - code=JsonRpcErrorCode.InvalidParams, - message=f"Invalid parameters: {e}", - ) - - -class JsonRpcApplication: - """ - Taking a page out of the ASGI standards, JsonRpcApplication is a collection - of JsonRpcCallbacks that can be used 
to handle incoming requests and - notifications. - """ - - def __init__( - self, - request_callbacks: Optional[dict[str, JsonRpcCallback]] = None, - notify_callbacks: Optional[dict[str, JsonRpcCallback]] = None, - ): - if request_callbacks is None: - request_callbacks = {} - if notify_callbacks is None: - notify_callbacks = {} - - self.request_callbacks = request_callbacks - self.notify_callbacks = notify_callbacks - - def handle_request( - self, msg: JsonRpcRequest, server: "JsonRpcServer" - ) -> tuple[JsonRpcResult | None, JsonRpcError | None] | None: - log.log(TRACE, "Handling request: %s", msg) - - if msg.id is not None: - log.log(TRACE, "Request is a bona fide request") - # bona fide request - if msg.method not in self.request_callbacks: - return None, JsonRpcError( - code=JsonRpcErrorCode.MethodNotFound, - message=f"Method not found: {msg.method}", - ) - - log.log(TRACE, "Calling method: %s", msg.method) - - result = self.request_callbacks[msg.method]( - params=msg.params, server=server, app=self - ) - - err = JsonRpcError( - code=JsonRpcErrorCode.InternalError, - message="Method did not return a tuple[JsonRpcResult | None, JsonRpcError | None]", - ) - - if not isinstance(result, tuple) or len(result) != 2: - return None, err - # NOTE: If we ever narrow down JsonRpcResult from Any, we should - # re-enable this check. 
- if not isinstance(result[0], (type(None), JsonRpcResult)): - return None, err - if not isinstance(result[1], (type(None), JsonRpcError)): - return None, err - - return result - - else: - log.log(TRACE, "Request is a notification") - - # notification - if msg.method not in self.notify_callbacks: - log.error(f"Notify method not found: {msg.method}") - return None - - log.log(TRACE, "Calling method: %s", msg.method) - - result = self.notify_callbacks[msg.method]( - params=msg.params, server=server, app=self - ) - - if result is not None: - return None, JsonRpcError( - code=JsonRpcErrorCode.InternalError, - message="Notification did not return None", - ) - - return None - - @overload - def add( - self, - func: JsonRpcRequestCallable, - *, - kind: Literal["request"] = "request", - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> JsonRpcCallback: ... - - @overload - def add( - self, - func: JsonRpcNotifyCallable, - *, - kind: Literal["notify"] = "notify", - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> JsonRpcCallback: ... - - @overload - def add( - self, - func: None = ..., - *, - kind: Literal["request"] = "request", - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> Callable[[JsonRpcRequestCallable], JsonRpcCallback]: ... - - @overload - def add( - self, - func: None = ..., - *, - kind: Literal["notify"] = "notify", - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> Callable[[JsonRpcNotifyCallable], JsonRpcCallback]: ... 
- - def add( - self, - func: JsonRpcRequestCallable | JsonRpcNotifyCallable | None = None, - *, - kind: Literal["request", "notify"] = "request", - method: str | None = None, - params_model: type[JsonRpcResult] | None = None, - include_server: bool = False, - include_app: bool = True, - ) -> ( - JsonRpcCallback - | Callable[[JsonRpcRequestCallable], JsonRpcCallback] - | Callable[[JsonRpcNotifyCallable], JsonRpcCallback] - ): - if method is None: - raise ValueError("Method name must be provided") - - if kind == "request": - callbacks = self.request_callbacks - else: - callbacks = self.notify_callbacks - - def decorator( - func: JsonRpcRequestCallable | JsonRpcNotifyCallable, - ) -> JsonRpcCallback: - callback = JsonRpcCallback( - func=func, - include_server=include_server, - include_app=include_app, - kind=kind, - method=method, - params_model=params_model, - ) - callbacks[method] = callback - - return callback - - if func: - return decorator(func) - else: - return decorator - - @overload - def add_notify( - self, - func: JsonRpcNotifyCallable, - *, - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> JsonRpcCallback: ... - - @overload - def add_notify( - self, - func: None = ..., - *, - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> Callable[[JsonRpcNotifyCallable], JsonRpcCallback]: ... 
- - def add_notify( - self, - func: JsonRpcNotifyCallable | None = None, - *, - method: str | None = None, - params_model: type[JsonRpcResult] | None = None, - include_server: bool = False, - include_app: bool = True, - ) -> JsonRpcCallback | Callable[[JsonRpcNotifyCallable], JsonRpcCallback]: - return self.add( - func=func, - kind="notify", - method=method, - params_model=params_model, - include_server=include_server, - include_app=include_app, - ) - - @overload - def add_request( - self, - func: JsonRpcRequestCallable, - *, - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> JsonRpcCallback: ... - - @overload - def add_request( - self, - func: None = ..., - *, - method: str | None = ..., - params_model: type[JsonRpcResult] | None = ..., - include_server: bool = ..., - include_app: bool = ..., - ) -> Callable[[JsonRpcRequestCallable], JsonRpcCallback]: ... - - def add_request( - self, - func: JsonRpcRequestCallable | None = None, - *, - method: str | None = None, - params_model: type[JsonRpcResult] | None = None, - include_server: bool = False, - include_app: bool = True, - ) -> JsonRpcCallback | Callable[[JsonRpcRequestCallable], JsonRpcCallback]: - return self.add( - func=func, - kind="request", - method=method, - params_model=params_model, - include_server=include_server, - include_app=include_app, - ) - - def generate_docs(self) -> str: - raise NotImplementedError() - - -class JsonRpcServer(threading.Thread): - """ - Taking a page from Python's ASGI standards, JsonRpcServer serves a - JsonRpcApplication. It is a thread that listens for incoming requests and - notifications on a JsonRpcStream, and sends responses over the same stream. - - We separate the two classes to allow one to define routes in different - files. - - Despite being called "server", you can also use this as a client. 
- """ - - def __init__( - self, - json_rpc_stream: JsonRpcStream, - app: JsonRpcApplication | None = None, - request_timeout: float = 60.0, - ): - if app is None: - app = JsonRpcApplication() - - threading.Thread.__init__(self) - - self.jsonrpc_stream = json_rpc_stream - self.app = app - - self.event_dict: dict[JsonRpcId, threading.Condition] = {} - self.response_dict: dict[JsonRpcId, JsonRpcResponse] = {} - self.next_id = 0 - self.request_timeout = request_timeout - - self.shutdown_flag = False - - def stop(self) -> None: - self.shutdown_flag = True - - def run(self) -> None: - log.debug("Server thread started") - - while not self.shutdown_flag: - msg = self.jsonrpc_stream.recv() - if msg is None: - log.info("Server quit") - break - - elif isinstance(msg, JsonRpcError): - self.jsonrpc_stream.send(JsonRpcResponse(error=msg)) - continue - - elif isinstance(msg, JsonRpcRequest): - log.log(TRACE, "Received request: %s", msg) - if (tmp := self.app.handle_request(msg, self)) is not None: - log.log(TRACE, "Sending response: %s", tmp) - result, error = tmp - self.jsonrpc_stream.send( - JsonRpcResponse(result=result, error=error, id=msg.id) - ) - continue - - elif isinstance(msg, JsonRpcResponse): - self.response_dict[msg.id] = msg - cond = self.event_dict[msg.id] - cond.acquire() - cond.notify() - cond.release() - - else: - log.error(f"Unknown message type: {type(msg)}") - - self.jsonrpc_stream.close() - - def send_request( - self, method: str, params: dict[str, Any] - ) -> JsonRpcResponse | JsonRpcError | None: - log.log(TRACE, "Sending request: %s", method) - current_id = self.next_id - self.next_id += 1 - cond = threading.Condition() - self.event_dict[current_id] = cond - - cond.acquire() - self.jsonrpc_stream.send( - JsonRpcRequest(method=method, params=params, id=current_id) - ) - - if self.shutdown_flag: - cond.release() - return None - - if not cond.wait(self.request_timeout): - cond.release() - return JsonRpcError( - code=JsonRpcErrorCode.InternalError, - 
message="Timeout waiting for response", - ) - cond.release() - - self.event_dict.pop(current_id) - return self.response_dict.pop(current_id) - - def send_notification(self, method: str, params: dict[str, Any]) -> None: - self.jsonrpc_stream.send(JsonRpcRequest(method=method, params=params)) - - -class JsonRpcLoggingHandler(logging.Handler): - def __init__(self, server: JsonRpcServer, method: str = "logMessage"): - logging.Handler.__init__(self) - self.server = server - self.method = method - - def emit(self, record: logging.LogRecord) -> None: - try: - params = { - "name": record.name, - "levelno": record.levelno, - "levelname": record.levelname, - "pathname": record.pathname, - "filename": record.filename, - "module": record.module, - "lineno": record.lineno, - "funcName": record.funcName, - "created": record.created, - "asctime": record.asctime, - "msecs": record.msecs, - "relativeCreated": record.relativeCreated, - "thread": record.thread, - "threadName": record.threadName, - "process": record.process, - "message": record.getMessage(), - } - - self.server.send_notification(self.method, params) - except Exception: - print("Failed to log message", file=sys.stderr) - self.handleError(record) diff --git a/playpen/rpc_server/server.py b/playpen/rpc_server/server.py deleted file mode 100644 index 2e51226f..00000000 --- a/playpen/rpc_server/server.py +++ /dev/null @@ -1,154 +0,0 @@ -import logging -import sys -from pathlib import Path -from typing import Optional, cast - -from pydantic import BaseModel - -from kai.models.kai_config import KaiConfigModels -from playpen.rpc_server.rpc import ( - DEFAULT_FORMATTER, - TRACE, - JsonRpcApplication, - JsonRpcError, - JsonRpcErrorCode, - JsonRpcLoggingHandler, - JsonRpcRequestResult, - JsonRpcServer, -) - -log = logging.getLogger(__name__) - - -class KaiRpcApplicationConfig(BaseModel): - processId: Optional[int] - - rootUri: str - kantraUri: str - modelProvider: KaiConfigModels - kaiBackendUrl: str - - logLevel: str = "INFO" - 
stderrLogLevel: str = "TRACE" - fileLogLevel: Optional[str] = None - logDirUri: Optional[str] = None - - -class KaiRpcApplication(JsonRpcApplication): - def __init__(self) -> None: - super().__init__() - - self.initialized = False - self.config: Optional[KaiRpcApplicationConfig] = None - self.log = logging.getLogger("kai_rpc_application") - - -app = KaiRpcApplication() - -ERROR_NOT_INITIALIZED = JsonRpcError( - code=JsonRpcErrorCode.ServerErrorStart, - message="Server not initialized", -) - - -@app.add_request(method="shutdown", include_server=True) -def shutdown(app: KaiRpcApplication, server: JsonRpcServer) -> tuple[dict, None]: - server.shutdown_flag = True - - return {}, None - - -@app.add_request(method="exit", include_server=True) -def exit(app: KaiRpcApplication, server: JsonRpcServer) -> JsonRpcRequestResult: - server.shutdown_flag = True - - return {}, None - - -@app.add_request( - method="initialize", params_model=KaiRpcApplicationConfig, include_server=True -) -def initialize( - app: KaiRpcApplication, params: KaiRpcApplicationConfig, server: JsonRpcServer -) -> JsonRpcRequestResult: - if app.initialized: - return {}, JsonRpcError( - code=JsonRpcErrorCode.ServerErrorStart, - message="Server already initialized", - ) - - try: - app.config = params - - app.log.setLevel(TRACE) - app.log.handlers.clear() - app.log.filters.clear() - - stderr_handler = logging.StreamHandler(sys.stderr) - stderr_handler.setLevel(TRACE) - stderr_handler.setFormatter(DEFAULT_FORMATTER) - app.log.addHandler(stderr_handler) - - notify_handler = JsonRpcLoggingHandler(server) - notify_handler.setLevel(app.config.logLevel) - notify_handler.setFormatter(DEFAULT_FORMATTER) - app.log.addHandler(notify_handler) - - if app.config.fileLogLevel and app.config.logDirUri: - log_dir = Path(app.config.logDirUri) # FIXME: urlparse? 
- log_file = log_dir / "kai_rpc.log" - log_file.parent.mkdir(parents=True, exist_ok=True) - - file_handler = logging.FileHandler(log_file) - file_handler.setLevel(app.config.fileLogLevel) - file_handler.setFormatter(DEFAULT_FORMATTER) - app.log.addHandler(file_handler) - - app.log.info(f"Initialized with config: {app.config}") - - except Exception as e: - return {}, JsonRpcError( - code=JsonRpcErrorCode.InvalidParams, - message=str(e), - ) - - app.initialized = True - - return {}, None - - -@app.add_request(method="setConfig", params_model=dict, include_server=True) -def set_config( - app: KaiRpcApplication, params: dict, server: JsonRpcServer -) -> JsonRpcRequestResult: - if not app.initialized: - return {}, ERROR_NOT_INITIALIZED - - # Basically reset everything - app.initialized = False - return cast(JsonRpcRequestResult, initialize(app=app, params=params, server=server)) - - -@app.add_request(method="getRAGSolution") -def get_rag_solution(app: KaiRpcApplication) -> JsonRpcRequestResult: - if not app.initialized: - return {}, ERROR_NOT_INITIALIZED - - return {}, None - - -@app.add_request(method="getCodeplanAgentSolution") -def get_codeplan_agent_solution(app: KaiRpcApplication) -> JsonRpcRequestResult: - if not app.initialized: - return {}, ERROR_NOT_INITIALIZED - - return {}, None - - -# if __name__ == "__main__": -# # with __import__("ipdb").launch_ipdb_on_exception(): -# file_path = Path(__file__).resolve() -# docs_path = file_path.parent / "docs.md" -# print(docs_path) -# with open(str(docs_path), "w") as f: -# f.write(app.generate_docs()) diff --git a/pyproject.toml b/pyproject.toml index 389158c3..ea096fcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ # --- Testing dependencies --- "coverage==7.6.0", + "ipdb==0.13.13", # --- For notebook development --- "jupyter==1.0.0", @@ -62,6 +63,8 @@ dependencies = [ "typer==0.9.0", # For potential CLI stuff "loguru==0.7.2", # For potential logging improvements "unidiff==0.7.5", + 
+ "imgui[sdl2]", ] requires-python = ">=3.11" authors = [