From a65c29adc027b3615154cab73aaedd58a6aa23da Mon Sep 17 00:00:00 2001
From: "Jason R. Coombs" <jaraco@jaraco.com>
Date: Tue, 23 Jul 2024 08:36:16 -0400
Subject: [PATCH] Prioritize valid dists to invalid dists when retrieving by
 name.

Closes python/importlib_metadata#489

Upstream-Status: Backport [https://github.com/python/importlib_metadata/commit/a65c29adc027b3615154cab73aaedd58a6aa23da]
Signed-off-by: Ross Burton <ross.burton@arm.com>
---
 Lib/importlib/metadata/__init__.py   | 14 +++-
 Lib/importlib/metadata/_itertools.py | 98 ++++++++++++++++++++++++++++
 2 files changed, 110 insertions(+), 2 deletions(-)

diff --git a/Lib/importlib/metadata/__init__.py b/Lib/importlib/metadata/__init__.py
index 8ce62dd..085378c 100644
--- a/Lib/importlib/metadata/__init__.py
+++ b/Lib/importlib/metadata/__init__.py
@@ -21,7 +21,7 @@ import collections
 from . import _meta
 from ._collections import FreezableDefaultDict, Pair
 from ._functools import method_cache, pass_none
-from ._itertools import always_iterable, unique_everseen
+from ._itertools import always_iterable, bucket, unique_everseen
 from ._meta import PackageMetadata, SimplePath
 
 from contextlib import suppress
@@ -404,7 +404,7 @@ class Distribution(DeprecatedNonAbstract):
         if not name:
             raise ValueError("A distribution name is required.")
         try:
-            return next(iter(cls.discover(name=name)))
+            return next(iter(cls._prefer_valid(cls.discover(name=name))))
         except StopIteration:
             raise PackageNotFoundError(name)
 
@@ -428,6 +428,16 @@ class Distribution(DeprecatedNonAbstract):
             resolver(context) for resolver in cls._discover_resolvers()
         )
 
+    @staticmethod
+    def _prefer_valid(dists: Iterable[Distribution]) -> Iterable[Distribution]:
+        """
+        Prefer (move to the front) distributions that have metadata.
+
+        Ref python/importlib_resources#489.
+        """
+        buckets = bucket(dists, lambda dist: bool(dist.metadata))
+        return itertools.chain(buckets[True], buckets[False])
+
     @staticmethod
     def at(path: str | os.PathLike[str]) -> Distribution:
         """Return a Distribution for the indicated metadata path.
diff --git a/Lib/importlib/metadata/_itertools.py b/Lib/importlib/metadata/_itertools.py
index d4ca9b9..79d3719 100644
--- a/Lib/importlib/metadata/_itertools.py
+++ b/Lib/importlib/metadata/_itertools.py
@@ -1,3 +1,4 @@
+from collections import defaultdict, deque
 from itertools import filterfalse
 
 
@@ -71,3 +72,100 @@ def always_iterable(obj, base_type=(str, bytes)):
         return iter(obj)
     except TypeError:
         return iter((obj,))
+
+
+# Copied from more_itertools 10.3
+class bucket:
+    """Wrap *iterable* and return an object that buckets the iterable into
+    child iterables based on a *key* function.
+
+        >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
+        >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
+        >>> sorted(list(s))  # Get the keys
+        ['a', 'b', 'c']
+        >>> a_iterable = s['a']
+        >>> next(a_iterable)
+        'a1'
+        >>> next(a_iterable)
+        'a2'
+        >>> list(s['b'])
+        ['b1', 'b2', 'b3']
+
+    The original iterable will be advanced and its items will be cached until
+    they are used by the child iterables. This may require significant storage.
+
+    By default, attempting to select a bucket to which no items belong  will
+    exhaust the iterable and cache all values.
+    If you specify a *validator* function, selected buckets will instead be
+    checked against it.
+
+        >>> from itertools import count
+        >>> it = count(1, 2)  # Infinite sequence of odd numbers
+        >>> key = lambda x: x % 10  # Bucket by last digit
+        >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
+        >>> s = bucket(it, key=key, validator=validator)
+        >>> 2 in s
+        False
+        >>> list(s[2])
+        []
+
+    """
+
+    def __init__(self, iterable, key, validator=None):
+        self._it = iter(iterable)
+        self._key = key
+        self._cache = defaultdict(deque)
+        self._validator = validator or (lambda x: True)
+
+    def __contains__(self, value):
+        if not self._validator(value):
+            return False
+
+        try:
+            item = next(self[value])
+        except StopIteration:
+            return False
+        else:
+            self._cache[value].appendleft(item)
+
+        return True
+
+    def _get_values(self, value):
+        """
+        Helper to yield items from the parent iterator that match *value*.
+        Items that don't match are stored in the local cache as they
+        are encountered.
+        """
+        while True:
+            # If we've cached some items that match the target value, emit
+            # the first one and evict it from the cache.
+            if self._cache[value]:
+                yield self._cache[value].popleft()
+            # Otherwise we need to advance the parent iterator to search for
+            # a matching item, caching the rest.
+            else:
+                while True:
+                    try:
+                        item = next(self._it)
+                    except StopIteration:
+                        return
+                    item_value = self._key(item)
+                    if item_value == value:
+                        yield item
+                        break
+                    elif self._validator(item_value):
+                        self._cache[item_value].append(item)
+
+    def __iter__(self):
+        for item in self._it:
+            item_value = self._key(item)
+            if self._validator(item_value):
+                self._cache[item_value].append(item)
+
+        yield from self._cache.keys()
+
+    def __getitem__(self, value):
+        if not self._validator(value):
+            return iter(())
+
+        return self._get_values(value)
