
[py-tx] Implement a new cleaner PDQ index solution from scratch #1613

Open
Dcallies opened this issue Aug 21, 2024 · 1 comment
Labels
help wanted, pdq (Items related to the pdq libraries or reference implementations), python-threatexchange (Items related to the threatexchange python tool / library)

Comments

@Dcallies
Contributor

When we built the PDQ index, it was our first attempt, and we made a lot of strange/bad choices.

Namely:

  • The hash -> id mapping is needlessly convoluted
  • The build and lookup implementations could be simplified

I think we could provide a second, much simpler implementation, and then find a way to swap it in.

The key elements:

Pass in the index type as an argument during construction

    def __init__(
        self,
        threshold: int = DEFAULT_MATCH_DIST,
        faiss_index: t.Optional[faiss.Index] = None,
    ) -> None:

Simplify the stored state of the index implementation

        # Body of __init__
        super().__init__()
        if faiss_index is None:
            # Default to brute force (exact search)
            faiss_index = faiss.IndexFlatL2(DIMENSIONALITY)
        self.faiss_index = _PDQHashIndex(faiss_index)
        self.threshold = threshold
        # distinct hash string -> faiss id
        self._deduper: t.Dict[str, int] = {}
        # faiss id -> every entry that shares that hash
        self._idx_to_entries: t.List[t.List[T]] = []
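A minimal pure-Python sketch of the bookkeeping these two fields imply, assuming `_deduper` maps each distinct hash string to the id it will get in faiss, and `_idx_to_entries[i]` holds every entry added with that hash. The class and method names are hypothetical (the issue calls `_dedupe_and_add` but does not show it). The sketch relies on flat and IVF faiss indexes assigning sequential ids `0..ntotal-1` on `add()`, so a list position doubles as the faiss id.

```python
import typing as t

T = t.TypeVar("T")


class HashDeduper(t.Generic[T]):
    """Hypothetical stand-in for the _deduper / _idx_to_entries state."""

    def __init__(self) -> None:
        # distinct hash string -> faiss id (position in _idx_to_entries)
        self._deduper: t.Dict[str, int] = {}
        # faiss id -> every entry that was added with that hash
        self._idx_to_entries: t.List[t.List[T]] = []

    def dedupe_and_add(self, signal_str: str, entry: T) -> t.Optional[int]:
        """Record entry; return the new faiss id if signal_str is unseen, else None.

        A non-None return means the caller still has to add the vector to faiss.
        """
        idx = self._deduper.get(signal_str)
        if idx is not None:
            # Duplicate hash: no new vector, just one more entry for the same id
            self._idx_to_entries[idx].append(entry)
            return None
        idx = len(self._idx_to_entries)
        self._deduper[signal_str] = idx
        self._idx_to_entries.append([entry])
        return idx
```

Because Python dicts preserve insertion order, `list(self._deduper)` yields the distinct hashes in faiss-id order, which is what lets a batch build convert the deduper's keys straight into the training/add array.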

Use a simpler inner wrapper to handle some of the PDQ details

    class _PDQHashIndex:
        """
        A wrapper around the faiss index for pickle serialization
        """

        def __init__(self, faiss_index: faiss.Index) -> None:
            self.faiss_index = faiss_index

        def search(
            self,
            queries: t.Sequence[str],
            threshold: int,
        ) -> t.List[t.List[t.Tuple[int, float]]]:
            """
            Search for each query hash. Returns one list of (id, distance) pairs
            per query, in the same order as `queries`.
            """
            qs = convert_pdq_strings_to_ndarray(queries)
            # In Python, range_search returns a triplet of 1D arrays (lims, D, I):
            # the results for query i are I[lims[i]:lims[i+1]] (neighbor ids) and
            # D[lims[i]:lims[i+1]] (distances). Only distances strictly below the
            # radius are returned, so pass threshold + 1 to make it inclusive.
            limits, D, I = self.faiss_index.range_search(qs, threshold + 1)

            results = []
            for i in range(len(queries)):
                start, end = limits[i], limits[i + 1]
                matches = [idx.item() for idx in I[start:end]]
                distances = [dist.item() for dist in D[start:end]]
                results.append(list(zip(matches, distances)))
            return results

        def __getstate__(self):
            # Pickle the index as the byte array produced by faiss.serialize_index
            data = faiss.serialize_index(self.faiss_index)
            return data

        def __setstate__(self, data):
            self.faiss_index = faiss.deserialize_index(data)
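The `+ 1` on the search radius and the `limits`/`D`/`I` slicing are easiest to see in a pure-Python emulation. This sketch assumes `convert_pdq_strings_to_ndarray` expands each 256-bit hash into a vector of 0.0/1.0 floats (bit order does not matter as long as it is consistent). Under that assumption the squared L2 distance that `IndexFlatL2` reports equals the Hamming distance, which is why `int(distf)` in `query` can be used directly as a PDQ distance. faiss `range_search` keeps only results strictly below the radius. `hex_to_bits` and `brute_force_range_search` are hypothetical helpers, not faiss or pytx APIs.

```python
import typing as t


def hex_to_bits(pdq_hex: str) -> t.List[float]:
    """Hypothetical: expand a hex hash into one 0.0/1.0 float per bit (MSB first)."""
    n_bits = len(pdq_hex) * 4
    value = int(pdq_hex, 16)
    return [float((value >> (n_bits - 1 - i)) & 1) for i in range(n_bits)]


def squared_l2(a: t.Sequence[float], b: t.Sequence[float]) -> float:
    # What IndexFlatL2 reports: the *squared* Euclidean distance
    return sum((x - y) ** 2 for x, y in zip(a, b))


def brute_force_range_search(
    queries: t.Sequence[t.Sequence[float]],
    database: t.Sequence[t.Sequence[float]],
    radius: float,
) -> t.Tuple[t.List[int], t.List[float], t.List[int]]:
    """Mimic the (lims, D, I) layout of faiss range_search, keeping dist < radius."""
    lims: t.List[int] = [0]
    D: t.List[float] = []
    I: t.List[int] = []
    for q in queries:
        for db_id, vec in enumerate(database):
            dist = squared_l2(q, vec)
            if dist < radius:  # strict comparison, hence radius = threshold + 1
                D.append(dist)
                I.append(db_id)
        # Results for this query end here; the next query's start here
        lims.append(len(I))
    return lims, D, I
```

For example, with `threshold = 4`, searching an all-`f` hash against itself, a hash that differs only in its last hex digit (4 bits), and an all-`0` hash returns ids 0 and 1 with distances 0.0 and 4.0. Passing `threshold` itself as the radius would silently drop the exact-threshold match at id 1.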

Putting it together with search

    def query(self, query: str) -> t.List[IndexMatch[T]]:
        results = self.faiss_index.search([query], self.threshold)
        return [
            IndexMatch(int(distf), entry)
            for idx, distf in results[0]
            for entry in self._idx_to_entries[idx]
        ]

Dynamically select the index type in build()

    @classmethod
    def build(cls: t.Type[Self], entries: t.Iterable[t.Tuple[str, T]]) -> Self:
        """
        Faiss has many potential options that we can choose based on the size of the index.
        """
        entry_list = list(entries)
        # Size the index by distinct hashes: duplicates are stored in faiss only once,
        # and IVF training needs at least as many points as centroids
        xn = len({signal_str for signal_str, _ in entry_list})
        if xn < 1024:  # If small enough, just use brute force
            return super().build(entry_list)
        centroids = pick_n_centroids(xn)
        index = faiss.index_factory(DIMENSIONALITY, f"IVF{centroids},Flat")  # TODO - use the same magic factory string as the old one does
        # Squelch warnings about not having enough points...
        index.cp.min_points_per_centroid = 1
        index.nprobe = 16  # 16-64 should be high enough accuracy for 1-10M
        ret = cls(faiss_index=index)
        for signal_str, entry in entry_list:
            ret._dedupe_and_add(signal_str, entry, add_to_faiss=False)
        # _deduper preserves insertion order, so row i of xb gets faiss id i
        xb = convert_pdq_strings_to_ndarray(list(ret._deduper))
        index.train(xb)
        index.add(xb)
        return ret
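`pick_n_centroids` is not defined in the issue. One possible heuristic, offered only as an assumption, follows the rule of thumb in the faiss wiki's index-selection guidelines of roughly 4·√N to 16·√N inverted lists for N vectors (here, the distinct hashes being trained on). It takes the low end by default and never returns more centroids than there are points, since k-means cannot train more clusters than it has points. The function signature and the `per_sqrt_n` parameter are hypothetical.

```python
import math

# Mirrors the `< 1024` brute-force cutoff in build() above
BRUTE_FORCE_CUTOFF = 1024


def pick_n_centroids(n_vectors: int, per_sqrt_n: int = 4) -> int:
    """Hypothetical heuristic for the number of IVF lists.

    faiss's index-selection guidelines suggest roughly 4*sqrt(N)..16*sqrt(N)
    lists; this picks the low end by default. The result is capped at
    n_vectors because k-means needs at least one point per centroid.
    """
    if n_vectors < BRUTE_FORCE_CUTOFF:
        raise ValueError(f"{n_vectors} vectors: use the brute-force index instead")
    return min(n_vectors, int(per_sqrt_n * math.sqrt(n_vectors)))
```

For example, 1,024 distinct hashes yield 128 lists and 1,000,000 yield 4,000. Fewer lists make each probe scan more vectors, which is part of why `nprobe = 16` can stay relatively small.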

Test everything

Add a robust set of unit tests for this functionality:

  1. Test 0 entries
  2. Test a sample set of entries
  3. Test at least 1024 distinct hashes, so the non-brute-force (IVF) path is exercised
  4. Serialization and deserialization
  5. Duplicate hashes return every entry that shares the hash
  6. Test the conditions from [pytx] No match results if creating a local_file with only 1 hash in it #1318
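One way to keep these tests honest, offered as a suggestion rather than part of the plan above, is to check the faiss-backed index against a deliberately naive pure-Python oracle that scans every (hash, entry) pair with no deduplication. The sketch assumes hashes are hex strings, uses Hamming distance with an inclusive threshold, and covers cases 1, 4, 5 and 6 (empty index, pickling, duplicate hashes, a single-hash index). `NaiveOracle`, its default threshold of 31, and the test names are hypothetical; in a real suite the same assertions would also run against the new index class.

```python
import pickle
import typing as t
import unittest


class NaiveOracle:
    """Hypothetical brute-force reference: no dedupe, no faiss."""

    def __init__(self, threshold: int = 31) -> None:  # assumed default threshold
        self.threshold = threshold
        self.pairs: t.List[t.Tuple[str, t.Any]] = []

    @classmethod
    def build(
        cls, entries: t.Iterable[t.Tuple[str, t.Any]], threshold: int = 31
    ) -> "NaiveOracle":
        ret = cls(threshold)
        ret.pairs.extend(entries)
        return ret

    def query(self, query: str) -> t.List[t.Tuple[int, t.Any]]:
        out = []
        for hash_str, entry in self.pairs:
            # Hamming distance between two hex hashes
            dist = bin(int(hash_str, 16) ^ int(query, 16)).count("1")
            if dist <= self.threshold:  # inclusive, like range_search(threshold + 1)
                out.append((dist, entry))
        return out


A = "f" * 64
B = "f" * 63 + "0"  # 4 bits away from A
FAR = "0" * 64  # 256 bits away from A


class OracleTests(unittest.TestCase):
    def test_empty(self) -> None:
        self.assertEqual(NaiveOracle.build([]).query(A), [])

    def test_single_hash(self) -> None:
        # The one-hash scenario from #1318
        self.assertEqual(NaiveOracle.build([(A, "only")]).query(A), [(0, "only")])

    def test_duplicates_return_every_entry(self) -> None:
        index = NaiveOracle.build([(A, "x"), (A, "y"), (FAR, "z")])
        self.assertEqual(sorted(e for _, e in index.query(B)), ["x", "y"])

    def test_pickle_round_trip(self) -> None:
        index = NaiveOracle.build([(A, "x"), (B, "y")])
        restored = pickle.loads(pickle.dumps(index))
        self.assertEqual(restored.query(A), index.query(A))
```

Run it with `python -m unittest`. Asserting that the new index and the oracle return the same (distance, entry) pairs for the same inputs turns each case into a differential test.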

Rollout plan

After we confirm that everything is working as expected, we'll swap out the index class that the PDQ signal type uses by default. I think we can get away without a major version bump for this.

Dcallies added the pdq and python-threatexchange labels on Aug 22, 2024
@zackjh3

zackjh3 commented Oct 3, 2024

I will start working on a fix for this issue at the Hackathon.

Projects
Status: No status
Development

No branches or pull requests

3 participants