sort embeddings by name (case insensitive)

This commit is contained in:
Brad Smith
2023-04-08 15:58:00 -04:00
parent 22bcc7be42
commit 27b9ec60e4

View File

@@ -2,7 +2,7 @@ import os
import sys import sys
import traceback import traceback
import inspect import inspect
from collections import namedtuple from collections import namedtuple, OrderedDict
import torch import torch
import tqdm import tqdm
@@ -108,7 +108,7 @@ class DirWithTextualInversionEmbeddings:
class EmbeddingDatabase: class EmbeddingDatabase:
def __init__(self): def __init__(self):
self.ids_lookup = {} self.ids_lookup = {}
self.word_embeddings = {} self.word_embeddings = OrderedDict()
self.skipped_embeddings = {} self.skipped_embeddings = {}
self.expected_shape = -1 self.expected_shape = -1
self.embedding_dirs = {} self.embedding_dirs = {}
@@ -233,6 +233,9 @@ class EmbeddingDatabase:
self.load_from_dir(embdir) self.load_from_dir(embdir)
embdir.update() embdir.update()
# re-sort word_embeddings because load_from_dir may not load in alphabetic order.
self.word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if self.previously_displayed_embeddings != displayed_embeddings: if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings self.previously_displayed_embeddings = displayed_embeddings