
Creating Extensions

HNSW is designed to be extensible, allowing you to create custom extensions that enhance its functionality. This guide will walk you through the process of creating your own extension.

Design Principles

When creating an extension for HNSW, keep the following design principles in mind:

  1. Composition over inheritance: Extensions should wrap the core HNSW graph rather than extending it through inheritance.
  2. Type safety: Use Go generics to ensure type safety.
  3. Minimal dependencies: Extensions should have minimal dependencies on external packages.
  4. Performance: Extensions should be designed with performance in mind, avoiding unnecessary allocations and computations.
  5. Consistency: Extensions should follow the same API patterns as the core HNSW library.
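
Go has no class inheritance, so the first principle in practice means holding the core graph in a named struct field rather than embedding it. The sketch below illustrates the idea with stand-in types: `Index`, `mapIndex`, and `CountingIndex` are hypothetical names invented for this example and are not part of the hnsw package.

```go
package main

import "fmt"

// Index is a hypothetical stand-in for the subset of the core graph API
// that an extension needs. It is NOT the hnsw package's real interface.
type Index[K comparable] interface {
    Add(key K, vec []float32) error
    Len() int
}

// mapIndex is a toy Index that exists only to make this sketch runnable.
type mapIndex[K comparable] struct {
    vecs map[K][]float32
}

func newMapIndex[K comparable]() *mapIndex[K] {
    return &mapIndex[K]{vecs: make(map[K][]float32)}
}

func (m *mapIndex[K]) Add(key K, vec []float32) error {
    if len(vec) == 0 {
        return fmt.Errorf("empty vector for key %v", key)
    }
    m.vecs[key] = vec
    return nil
}

func (m *mapIndex[K]) Len() int { return len(m.vecs) }

// CountingIndex holds an Index in a named field (composition) instead of
// embedding it, so only the methods defined on CountingIndex are exposed.
type CountingIndex[K comparable] struct {
    inner Index[K]
    adds  int // number of successful Add calls
}

func NewCountingIndex[K comparable](inner Index[K]) *CountingIndex[K] {
    return &CountingIndex[K]{inner: inner}
}

// Add delegates to the wrapped index and counts successful calls.
func (c *CountingIndex[K]) Add(key K, vec []float32) error {
    if err := c.inner.Add(key, vec); err != nil {
        return err
    }
    c.adds++
    return nil
}

// Adds reports how many Add calls succeeded.
func (c *CountingIndex[K]) Adds() int { return c.adds }

func main() {
    idx := NewCountingIndex[string](newMapIndex[string]())
    _ = idx.Add("a", []float32{0.1, 0.2})
    err := idx.Add("b", nil) // rejected by the wrapped index
    fmt.Println(idx.Adds(), err != nil) // 1 true
}
```

Because the wrapped index is a field, `Len` is not callable on `CountingIndex` unless the wrapper chooses to expose it, which keeps the extension's API surface explicit.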

Extension Structure

A typical HNSW extension consists of the following components:

  1. Wrapper type: A struct that wraps the core HNSW graph and adds additional functionality.
  2. Storage interface: An interface that defines how to store and retrieve extension-specific data.
  3. Storage implementation: One or more implementations of the storage interface.
  4. Helper functions: Functions that make it easier to use the extension.

Example: Creating a Simple Extension

Let's create a simple extension that adds tagging functionality to the HNSW graph. This extension will allow you to tag nodes and search for nodes with specific tags.

Step 1: Define the Extension Package

Create a new file called tags.go in a new package:

package tags

import (
    "fmt"

    "github.com/TFMV/hnsw"
)

// Tags represents a set of tags for a node.
type Tags []string

// TagStore defines the interface for storing and retrieving tags.
type TagStore[K comparable] interface {
    // Add adds tags for a key.
    Add(key K, tags Tags) error

    // Get retrieves tags for a key.
    Get(key K) (Tags, error)

    // Delete removes tags for a key.
    Delete(key K) error

    // Search returns keys that have all the specified tags.
    Search(tags ...string) ([]K, error)
}

Step 2: Implement a Storage Backend

Now, let's implement an in-memory storage backend for our tags:

// MemoryTagStore is an in-memory implementation of TagStore.
type MemoryTagStore[K comparable] struct {
    tags map[K]Tags
}

// NewMemoryTagStore creates a new MemoryTagStore.
func NewMemoryTagStore[K comparable]() *MemoryTagStore[K] {
    return &MemoryTagStore[K]{
        tags: make(map[K]Tags),
    }
}

// Add adds tags for a key.
func (s *MemoryTagStore[K]) Add(key K, tags Tags) error {
    s.tags[key] = tags
    return nil
}

// Get retrieves tags for a key.
func (s *MemoryTagStore[K]) Get(key K) (Tags, error) {
    tags, ok := s.tags[key]
    if !ok {
        return nil, fmt.Errorf("no tags found for key %v", key)
    }
    return tags, nil
}

// Delete removes tags for a key.
func (s *MemoryTagStore[K]) Delete(key K) error {
    delete(s.tags, key)
    return nil
}

// Search returns keys that have all the specified tags.
func (s *MemoryTagStore[K]) Search(tags ...string) ([]K, error) {
    var result []K

    for key, nodeTags := range s.tags {
        if containsAll(nodeTags, tags) {
            result = append(result, key)
        }
    }

    return result, nil
}

// containsAll returns true if all items in subset are in set.
func containsAll(set, subset []string) bool {
    for _, item := range subset {
        found := false
        for _, setItem := range set {
            if setItem == item {
                found = true
                break
            }
        }
        if !found {
            return false
        }
    }
    return true
}
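
MemoryTagStore.Search visits every stored key on each call. For larger datasets, one possible alternative is a store that also maintains an inverted index from tag to keys. The following is a minimal sketch under stated assumptions: `IndexedTagStore` is a hypothetical name, it uses `[]string` in place of `Tags` so that it compiles on its own, and it mirrors the TagStore method set without being part of the library.

```go
package main

import (
    "fmt"
    "sort"
)

// IndexedTagStore is a hypothetical store that keeps an inverted index
// (tag -> set of keys) alongside the per-key tag lists.
type IndexedTagStore[K comparable] struct {
    byKey map[K][]string
    byTag map[string]map[K]struct{}
}

func NewIndexedTagStore[K comparable]() *IndexedTagStore[K] {
    return &IndexedTagStore[K]{
        byKey: make(map[K][]string),
        byTag: make(map[string]map[K]struct{}),
    }
}

// Add replaces any existing tags for key and updates the inverted index.
func (s *IndexedTagStore[K]) Add(key K, tags []string) error {
    _ = s.Delete(key)                             // drop stale index entries first
    s.byKey[key] = append([]string(nil), tags...) // copy to avoid aliasing the caller's slice
    for _, t := range tags {
        if s.byTag[t] == nil {
            s.byTag[t] = make(map[K]struct{})
        }
        s.byTag[t][key] = struct{}{}
    }
    return nil
}

// Get retrieves the tags stored for key.
func (s *IndexedTagStore[K]) Get(key K) ([]string, error) {
    tags, ok := s.byKey[key]
    if !ok {
        return nil, fmt.Errorf("no tags found for key %v", key)
    }
    return tags, nil
}

// Delete removes key and its index entries; a missing key is a no-op.
func (s *IndexedTagStore[K]) Delete(key K) error {
    for _, t := range s.byKey[key] {
        delete(s.byTag[t], key)
        if len(s.byTag[t]) == 0 {
            delete(s.byTag, t)
        }
    }
    delete(s.byKey, key)
    return nil
}

// Search returns the keys that carry all of the given tags. With no tags
// it returns every key, matching the behavior of containsAll.
func (s *IndexedTagStore[K]) Search(tags ...string) ([]K, error) {
    var result []K
    if len(tags) == 0 {
        for key := range s.byKey {
            result = append(result, key)
        }
        return result, nil
    }
    // Start from the smallest key set to minimize membership checks.
    smallest := s.byTag[tags[0]]
    for _, t := range tags[1:] {
        if len(s.byTag[t]) < len(smallest) {
            smallest = s.byTag[t]
        }
    }
    for key := range smallest {
        ok := true
        for _, t := range tags {
            if _, has := s.byTag[t][key]; !has {
                ok = false
                break
            }
        }
        if ok {
            result = append(result, key)
        }
    }
    return result, nil
}

func main() {
    s := NewIndexedTagStore[int]()
    _ = s.Add(1, []string{"electronics", "smartphone"})
    _ = s.Add(2, []string{"electronics", "laptop"})
    _ = s.Add(3, []string{"clothing", "shirt"})

    keys, _ := s.Search("electronics")
    sort.Ints(keys) // map iteration order is not deterministic
    fmt.Println(keys) // [1 2]
}
```

The trade-off is extra memory and slightly slower writes in exchange for searches whose cost depends on the rarest requested tag rather than on the total number of nodes.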

Step 3: Create a Wrapper Type

Now, let's create a wrapper type that combines the HNSW graph with our tag store:

// TaggedGraph wraps an HNSW graph and adds tagging functionality.
type TaggedGraph[K comparable] struct {
    graph    *hnsw.Graph[K]
    tagStore TagStore[K]
}

// NewTaggedGraph creates a new TaggedGraph.
func NewTaggedGraph[K comparable]() *TaggedGraph[K] {
    return &TaggedGraph[K]{
        graph:    hnsw.NewGraph[K](),
        tagStore: NewMemoryTagStore[K](),
    }
}

// Add adds a node with tags to the graph.
func (g *TaggedGraph[K]) Add(node hnsw.Node[K], tags Tags) error {
    // Add the node to the graph
    if err := g.graph.Add(node); err != nil {
        return err
    }

    // Add the tags
    return g.tagStore.Add(node.Key, tags)
}

// Delete removes a node and its tags from the graph.
func (g *TaggedGraph[K]) Delete(key K) error {
    // Delete the node from the graph
    if err := g.graph.Delete(key); err != nil {
        return err
    }

    // Delete the tags
    return g.tagStore.Delete(key)
}

// Get retrieves a node and its tags from the graph.
func (g *TaggedGraph[K]) Get(key K) (hnsw.Node[K], Tags, error) {
    // Get the node from the graph
    node, err := g.graph.Get(key)
    if err != nil {
        return hnsw.Node[K]{}, nil, err
    }

    // Get the tags
    tags, err := g.tagStore.Get(key)
    if err != nil {
        return node, nil, err
    }

    return node, tags, nil
}

// Search searches for the nearest neighbors of the query vector.
func (g *TaggedGraph[K]) Search(query []float32, k int) ([]hnsw.SearchResult[K], error) {
    return g.graph.Search(query, k)
}

// SearchWithTags searches for the nearest neighbors of the query vector that have all the specified tags.
func (g *TaggedGraph[K]) SearchWithTags(query []float32, k int, tags ...string) ([]hnsw.SearchResult[K], error) {
    // Find nodes with the specified tags
    keys, err := g.tagStore.Search(tags...)
    if err != nil {
        return nil, err
    }

    // If no nodes have the specified tags, return an empty result
    if len(keys) == 0 {
        return []hnsw.SearchResult[K]{}, nil
    }

    // Search for the nearest neighbors among the nodes with the specified tags
    return g.graph.SearchWithFilter(query, k, func(key K) bool {
        // Named candidate so the loop variable does not shadow the k parameter.
        for _, candidate := range keys {
            if candidate == key {
                return true
            }
        }
        return false
    })
}
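
The filter in SearchWithTags scans the `keys` slice linearly each time it is called, and the graph may call it for many candidate nodes. One way to reduce that cost is to build a set once and return a predicate that checks membership in constant time. The helper below is a sketch: `keyFilter` is a hypothetical name, and passing its result as the third argument assumes SearchWithFilter accepts a `func(K) bool`, as in the code above.

```go
package main

import "fmt"

// keyFilter builds a set from keys once and returns a predicate that
// reports membership in O(1) per call instead of scanning the slice.
func keyFilter[K comparable](keys []K) func(K) bool {
    allowed := make(map[K]struct{}, len(keys))
    for _, key := range keys {
        allowed[key] = struct{}{}
    }
    return func(key K) bool {
        _, ok := allowed[key]
        return ok
    }
}

func main() {
    match := keyFilter([]int{1, 2})
    fmt.Println(match(1), match(3)) // true false
}
```

Inside SearchWithTags, the call would then read roughly `g.graph.SearchWithFilter(query, k, keyFilter(keys))`.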

Step 4: Add Tests

It's important to test your extension to ensure it works correctly:

package tags

import (
    "testing"

    "github.com/TFMV/hnsw"
)

func TestTaggedGraph(t *testing.T) {
    // Create a new tagged graph
    graph := NewTaggedGraph[int]()

    // Add some nodes with tags
    nodes := []hnsw.Node[int]{
        {Key: 1, Value: []float32{0.1, 0.2, 0.3}},
        {Key: 2, Value: []float32{0.2, 0.3, 0.4}},
        {Key: 3, Value: []float32{0.3, 0.4, 0.5}},
    }

    tags := []Tags{
        {"electronics", "smartphone"},
        {"electronics", "laptop"},
        {"clothing", "shirt"},
    }

    for i, node := range nodes {
        err := graph.Add(node, tags[i])
        if err != nil {
            t.Fatalf("Error adding node: %v", err)
        }
    }

    // Test retrieving a node with tags
    node, nodeTags, err := graph.Get(1)
    if err != nil {
        t.Fatalf("Error getting node: %v", err)
    }

    if node.Key != 1 {
        t.Errorf("Expected key 1, got %v", node.Key)
    }

    if len(nodeTags) != 2 || nodeTags[0] != "electronics" || nodeTags[1] != "smartphone" {
        t.Errorf("Expected tags [electronics smartphone], got %v", nodeTags)
    }

    // Test searching with tags
    query := []float32{0.15, 0.25, 0.35}
    results, err := graph.SearchWithTags(query, 2, "electronics")
    if err != nil {
        t.Fatalf("Error searching with tags: %v", err)
    }

    // Fatalf stops the test here so that results[0] below cannot panic.
    if len(results) != 2 {
        t.Fatalf("Expected 2 results, got %d", len(results))
    }

    // The closest node should be node 1
    if results[0].Key != 1 {
        t.Errorf("Expected key 1, got %v", results[0].Key)
    }

    // Test searching with multiple tags
    results, err = graph.SearchWithTags(query, 2, "electronics", "smartphone")
    if err != nil {
        t.Fatalf("Error searching with tags: %v", err)
    }

    // Fatalf stops the test here so that results[0] below cannot panic.
    if len(results) != 1 {
        t.Fatalf("Expected 1 result, got %d", len(results))
    }

    if results[0].Key != 1 {
        t.Errorf("Expected key 1, got %v", results[0].Key)
    }

    // Test deleting a node
    err = graph.Delete(1)
    if err != nil {
        t.Fatalf("Error deleting node: %v", err)
    }

    // Verify that the node is deleted
    _, _, err = graph.Get(1)
    if err == nil {
        t.Errorf("Expected error getting deleted node, got nil")
    }
}
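
The test above covers the main path. Edge cases in the tag matching are also worth pinning down, in particular that an empty tag list matches every tagged node. The sketch below lists such cases in table-driven form; `containsAll` is copied from Step 2 so the example compiles on its own, and in a real `_test.go` file the loop body would call `t.Errorf` instead of printing.

```go
package main

import "fmt"

// containsAll is copied from Step 2: it returns true if all items in
// subset are in set.
func containsAll(set, subset []string) bool {
    for _, item := range subset {
        found := false
        for _, setItem := range set {
            if setItem == item {
                found = true
                break
            }
        }
        if !found {
            return false
        }
    }
    return true
}

func main() {
    cases := []struct {
        name        string
        set, subset []string
        want        bool
    }{
        {"all present", []string{"a", "b"}, []string{"a"}, true},
        {"one missing", []string{"a"}, []string{"a", "b"}, false},
        {"empty subset matches everything", []string{"a"}, nil, true},
        {"empty set, non-empty subset", nil, []string{"a"}, false},
    }
    for _, c := range cases {
        got := containsAll(c.set, c.subset)
        fmt.Printf("%-32s got=%v want=%v\n", c.name, got, c.want)
    }
}
```

The third case means that calling SearchWithTags without any tags filters on "every node that has an entry in the tag store", which may or may not be the behavior you want to expose.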

Step 5: Document Your Extension

Finally, document your extension to make it easier for others to use:

# Tag Extension

The Tag Extension allows you to tag nodes in the HNSW graph and search for nodes with specific tags.

## Features

- Add tags to nodes
- Retrieve tags for nodes
- Search for nodes with specific tags
- Combine tag filtering with vector search

## Usage

```go
// Create a new tagged graph
graph := tags.NewTaggedGraph[int]()

// Add a node with tags
err := graph.Add(hnsw.Node[int]{
    Key:   1,
    Value: []float32{0.1, 0.2, 0.3},
}, tags.Tags{"electronics", "smartphone"})

// Search for nodes with tags
results, err := graph.SearchWithTags(query, 5, "electronics", "smartphone")
```

## Best Practices

- Keep the number of tags per node small: MemoryTagStore compares every node's tags against the requested tags on each search
- Use specific tags for better filtering
- Consider implementing a custom tag store for very large datasets

Best Practices for Extension Development

When developing extensions for HNSW, follow these best practices:

  1. Keep it simple: Extensions should do one thing and do it well.
  2. Document thoroughly: Provide clear documentation and examples.
  3. Test extensively: Write comprehensive tests to ensure your extension works correctly.
  4. Consider performance: Optimize your extension for performance, especially for large datasets.
  5. Follow Go conventions: Follow standard Go conventions for naming, error handling, and documentation.
  6. Provide flexibility: Allow users to customize the behavior of your extension.
  7. Maintain compatibility: Ensure your extension works with the latest version of HNSW.
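
One common Go idiom for the "provide flexibility" point is the functional options pattern, where the constructor applies defaults and callers override them selectively. The sketch below is illustrative only: `Store`, `Extension`, `WithStore`, and `WithMaxTags` are hypothetical names and do not exist in the hnsw package.

```go
package main

import "fmt"

// Store is a hypothetical stand-in for an extension's storage interface.
type Store interface{ Name() string }

type memoryStore struct{}

func (memoryStore) Name() string { return "memory" }

type diskStore struct{ path string }

func (d diskStore) Name() string { return "disk:" + d.path }

// Extension is a hypothetical wrapper configured through options.
type Extension struct {
    store   Store
    maxTags int
}

// Option customizes an Extension at construction time.
type Option func(*Extension)

// WithStore replaces the default storage backend.
func WithStore(s Store) Option { return func(e *Extension) { e.store = s } }

// WithMaxTags overrides the default per-node tag limit.
func WithMaxTags(n int) Option { return func(e *Extension) { e.maxTags = n } }

// New applies the defaults first, then each option in order.
func New(opts ...Option) *Extension {
    e := &Extension{store: memoryStore{}, maxTags: 16}
    for _, opt := range opts {
        opt(e)
    }
    return e
}

func main() {
    def := New()
    custom := New(WithStore(diskStore{path: "/tmp/tags"}), WithMaxTags(4))
    fmt.Println(def.store.Name(), def.maxTags)       // memory 16
    fmt.Println(custom.store.Name(), custom.maxTags) // disk:/tmp/tags 4
}
```

Applied to the tag extension, a constructor such as NewTaggedGraph could accept an option for the tag store instead of always creating a MemoryTagStore.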

Contributing Extensions

If you've created a useful extension for HNSW, consider contributing it back to the project. This helps the community and ensures your extension is maintained alongside the core library.

To contribute an extension:

  1. Fork the HNSW repository
  2. Add your extension to the hnsw-extensions directory
  3. Write comprehensive tests and documentation
  4. Submit a pull request

Next Steps