Skip to content

Commit

Permalink
Fix Arguments.arguments so it actually returns all arguments
Browse files Browse the repository at this point in the history
Closes #2213.

Arguments.arguments() has been modified so that it returns all arguments
as it should (according to its own doc). A test case was also added to
verify this.
  • Loading branch information
crazybolillo committed Jul 11, 2023
1 parent a7ab088 commit 6a4b8f7
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 16 deletions.
26 changes: 20 additions & 6 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,18 @@ def fromlineno(self) -> int:
@cached_property
def arguments(self):
"""Get all the arguments for this node, including positional only and positional and keyword"""
return list(itertools.chain((self.posonlyargs or ()), self.args or ()))
retval = list(itertools.chain((self.posonlyargs or ()), (self.args or ())))
if self.vararg:
retval.append(
Name(self.vararg, -1, -1, self, end_lineno=None, end_col_offset=None)
)
retval += self.kwonlyargs or ()
if self.kwarg:
retval.append(
Name(self.kwarg, -1, -1, self, end_lineno=None, end_col_offset=None)
)

return retval

def format_args(self, *, skippable_names: set[str] | None = None) -> str:
"""Get the arguments formatted as string.
Expand Down Expand Up @@ -910,15 +921,16 @@ def default_value(self, argname):
:raises NoDefault: If there is no default value defined for the
given argument.
"""
args = self.arguments
# Ignore *args and **kwargs
args = list(filter(lambda x: not isinstance(x, Name), self.arguments))
index = _find_arg(argname, self.kwonlyargs)[0]
if index is not None and self.kw_defaults[index] is not None:
return self.kw_defaults[index]
index = _find_arg(argname, args)[0]
if index is not None:
idx = index - (len(args) - len(self.defaults))
if idx >= 0:
return self.defaults[idx]
index = _find_arg(argname, self.kwonlyargs)[0]
if index is not None and self.kw_defaults[index] is not None:
return self.kw_defaults[index]
raise NoDefault(func=self.parent, name=argname)

def is_argument(self, name) -> bool:
Expand Down Expand Up @@ -955,7 +967,9 @@ def find_argname(self, argname, rec=DEPRECATED_ARGUMENT_DEFAULT):
stacklevel=2,
)
if self.arguments:
return _find_arg(argname, self.arguments)
index, argument = _find_arg(argname, self.arguments)
if not isinstance(argument, Name):
return index, argument
return None, None

def get_children(self):
Expand Down
12 changes: 2 additions & 10 deletions astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,11 +963,7 @@ def argnames(self) -> list[str]:
names = [elt.name for elt in self.args.arguments]
else:
names = []
if self.args.vararg:
names.append(self.args.vararg)
names += [elt.name for elt in self.args.kwonlyargs]
if self.args.kwarg:
names.append(self.args.kwarg)

return names

def infer_call_result(
Expand Down Expand Up @@ -1280,11 +1276,7 @@ def argnames(self) -> list[str]:
names = [elt.name for elt in self.args.arguments]
else:
names = []
if self.args.vararg:
names.append(self.args.vararg)
names += [elt.name for elt in self.args.kwonlyargs]
if self.args.kwarg:
names.append(self.args.kwarg)

return names

def getattr(
Expand Down
26 changes: 26 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Uninferable,
bases,
builder,
extract_node,
nodes,
parse,
test_utils,
Expand Down Expand Up @@ -1943,3 +1944,28 @@ def test_str_repr_no_warnings(node):
test_node = node(**args)
str(test_node)
repr(test_node)


def test_arguments_contains_all():
"""Ensure Arguments.arguments actually returns all available arguments"""

def manually_get_args(arg_node) -> set:
names = set()
if arg_node.args.vararg:
names.add(arg_node.args.vararg)
if arg_node.args.kwarg:
names.add(arg_node.args.kwarg)

names.update([x.name for x in arg_node.args.args])
names.update([x.name for x in arg_node.args.kwonlyargs])

return names

node = extract_node("""def a(fruit: str, *args, b=None, c=None, **kwargs): ...""")
assert manually_get_args(node) == {x.name for x in node.args.arguments}

node = extract_node("""def a(mango: int, b="banana", c=None, **kwargs): ...""")
assert manually_get_args(node) == {x.name for x in node.args.arguments}

node = extract_node("""def a(self, num = 10, *args): ...""")
assert manually_get_args(node) == {x.name for x in node.args.arguments}

0 comments on commit 6a4b8f7

Please sign in to comment.