
k-NN with Strings and Graph Visualization

    Authored by Christof Kaufmann

    This demonstrates how to implement the k-nearest neighbors algorithm on string data. Several different distance metrics can be tried out and customized.

    In addition, a directed neighbors graph is created and visualized.

    string-neighbors-graph.py
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """Demo for k-Nearest Neighbors with Strings
    
    This demonstrates how to implement the k-nearest neighbors algorithm on string data. Several
    different distance metrics can be tried out and customized.
    
    In addition, a directed neighbors graph is created and visualized.
    
    Prerequisites can be installed with:
    
        mamba install networkx pyvis
        pip install strsimpy
    """
    
    # %% imports
    import numpy as np
    import pandas as pd
    
    # try out some different string distance metrics
    # 1. Edit distance:
    # from strsimpy.levenshtein import Levenshtein
    # levenshtein = Levenshtein()
    # distance = levenshtein.distance
    
    # 2. Some experimental "perceived" distance
    # from strsimpy import SIFT4
    # sift = SIFT4()
    # distance = sift.distance
    
    # 3. Weighted edit distance, with custom costs
    from strsimpy.weighted_levenshtein import WeightedLevenshtein
    def insertion_deletion_cost(char):
        # additional spaces, dashes or underscores should not make a large difference
        # e.g. "1-123" ←→ "1123" has a distance of 0.1
        #      "1A123" ←→ "1123" has a distance of 0.75
        if char.isspace() or char in '-_':
            return 0.1
        return 0.75
    
    def substitution_cost(a, b):
        # same type of letter should not make a large difference
        # e.g. "123QA7" ←→ "123QB8" has a distance of 0.2
        #      "123QA7" ←→ "123QBX" has a distance of 1.1
        if (a.isdigit() and b.isdigit()) or \
           (a.isalpha() and b.isalpha()) or \
           (a.isspace() and b.isspace()):
            return 0.1
        return 1.0
    
    weighted_levenshtein = WeightedLevenshtein(
        substitution_cost_fn=substitution_cost,
        insertion_cost_fn=insertion_deletion_cost,
        deletion_cost_fn=insertion_deletion_cost)
    
    distance = weighted_levenshtein.distance
    
    # 4. n-Grams
    # from strsimpy.qgram import QGram
    # qgram = QGram(2)
    # distance = qgram.distance
    
    # 5. Overlap Coefficient, similar to n-Grams
    # from strsimpy.overlap_coefficient import OverlapCoefficient
    # overlapCoefficient = OverlapCoefficient()
    # distance = overlapCoefficient.distance
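
    # %% optional sanity checks (a minimal sketch, not required for the demo)
    # First print the example distances mentioned in the comments of the custom cost
    # functions (expected roughly 0.1, 0.75, 0.2 and 1.1), then compare the metrics
    # listed above on one example pair. All classes used here ship with strsimpy.
    from strsimpy.levenshtein import Levenshtein
    from strsimpy import SIFT4
    from strsimpy.qgram import QGram
    from strsimpy.overlap_coefficient import OverlapCoefficient

    for a, b in [('1-123', '1123'), ('1A123', '1123'),
                 ('123QA7', '123QB8'), ('123QA7', '123QBX')]:
        print(f'{a} <-> {b}: {weighted_levenshtein.distance(a, b)}')

    a, b = '978-1-4920-3264-9', '9781492032649'
    for name, fn in [('Levenshtein', Levenshtein().distance),
                     ('SIFT4', SIFT4().distance),
                     ('weighted Levenshtein', weighted_levenshtein.distance),
                     ('2-gram', QGram(2).distance),
                     ('overlap coefficient', OverlapCoefficient().distance)]:
        print(f'{name}: {fn(a, b)}')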
    
    # %% some test data
    data = [
        # Hyundai part numbers
        '99490-ADE50',
        '1KF40-AK700',
        '66700-1P001',
        '66797-1P001',
        '65110-1P700',
        '65230-1P500',
        '56900-1K000EQ',
        '11254-06161',
        '85020-1P000',
        '96550-1K100',
        '96200-1P500',
        'DP370-APU044CH0U',
        'DP370-APU060SH1U',
        'LP370-APE060SH1',
        '11290-08206B',
        '49580-2H366K',
        '11200-06453',
        '91568-45000',
        '12291-04103',
        '11250-08203',
        # Smart part numbers
        'A  4537513400',
        'A  4537513000',
        'A  4537512500',
        'A  4532920800',
        'Q0012473V001000000',
        'A  2711420172',
        'N910143006000',
        # Chery part numbers
        'Q1840880',
        '481H-1011030BA',
        '481H-1002041',
        '481FC-BJ1002001AA',
        'A11-3810010BB',
        'M11-8107010BA',
        'M11-6101350',
        # DOIs
        '10.1093/ajae/aaq063',
        '10.1371/journal.pgen.1001111',
        '10.3139/9783446463554',
        '10.5555/3378999',
        # ISBNs
        '978-3-16-148410-0',
        '978-1-4920-3264-9',
        '9781492032649',
        '1-56619-909-3',
        '1566199093',
    ]
    
    
    # %% simple k-NN test
    k = 3
    test = 'Q0012473V001000000'
    raw_dists = [distance(test, s) for s in data]
    dists = pd.Series(raw_dists, index=data).sort_values()
    
    # select k neighbors (or more if they have the same distances)
    neighbors = dists[dists <= dists.iloc[k-1]]
    neighbors
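

    # %% optional: the same lookup wrapped in a reusable helper (hypothetical name knn_query)
    # A minimal sketch that just repeats the tie-aware selection from above for any query string.
    def knn_query(query, samples, k=3):
        """Return the k nearest samples to query (or more if there are distance ties)."""
        d = pd.Series([distance(query, s) for s in samples], index=samples).sort_values()
        cutoff = d.iloc[k-1] if len(d) >= k else d.iloc[-1]
        return d[d <= cutoff]

    # query with an ISBN that has one digit changed, so it is not contained in the data
    knn_query('978-3-16-148410-5', data, k=3)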
    
    
    # %% create a neighbors graph
    import networkx as nx
    
    graph = nx.DiGraph()
    graph.add_nodes_from(data)
    
    k = 2
    normalize = False
    for i, sample in enumerate(data):
        # distances to other samples
        other_data = data[:i] + data[i+1:]
        raw_dists = [distance(sample, s) for s in other_data]
        dists = pd.Series(raw_dists, index=other_data).sort_values()
    
        if normalize:
            # divide by the longer string's length (times 10) so that distances of
            # long and short strings become comparable
            lengths = np.maximum(dists.index.str.len(), len(sample))
            dists /= lengths * 10
    
        # select k neighbors (or more if they have the same distances)
        dist_at_k = dists.iloc[k-1] if len(dists) >= k else dists.iloc[-1]
        neighbors = dists[dists <= dist_at_k]
    
        # add weighted neighbors as edges in a graph
        # (guard against zero distances, e.g. from the n-gram based metrics, which would
        # otherwise divide by zero when converting distance to edge weight)
        for n, d in neighbors.items():
            graph.add_edge(sample, n, weight=1/max(d, 1e-6))
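

    # %% optional: quick look at the resulting graph
    # How many edges the tie-aware neighbor selection produced and which connections
    # are the strongest (largest weight, i.e. smallest distance).
    print(graph.number_of_nodes(), 'nodes,', graph.number_of_edges(), 'edges')
    strongest = sorted(graph.edges(data='weight'), key=lambda e: e[2], reverse=True)[:5]
    for u, v, w in strongest:
        print(f'{u} -> {v}  (weight {w:.2f})')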
    
    
    # %% interactive visualization of the neighbors graph
    from pyvis.network import Network
    
    # notebook=True should work in Jupyter Notebooks (but not in VS Code)
    nt = Network(width='1800px', height='900px', notebook=False, cdn_resources='remote', directed=True)
    nt.from_nx(graph.copy())
    nt.toggle_physics(True)
    nt.show('nx.html')
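

    # %% optional: static fallback plot instead of the interactive pyvis output
    # A minimal sketch using networkx's built-in matplotlib drawing; it assumes matplotlib
    # is installed (not listed in the prerequisites above).
    import matplotlib.pyplot as plt

    pos = nx.spring_layout(graph, weight='weight', seed=42)
    plt.figure(figsize=(16, 9))
    nx.draw_networkx(graph, pos, node_size=50, font_size=6, arrowsize=8)
    plt.axis('off')
    plt.show()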