diff --git a/scGNN.py b/scGNN.py index 45b26a1..58e78a7 100644 --- a/scGNN.py +++ b/scGNN.py @@ -548,7 +548,11 @@ def train(epoch, train_loader=train_loader, EMFlag=False, taskType='celltype', s for i in range(len(set(listResult))): clusterIndexList.append([]) for i in range(len(listResult)): - clusterIndexList[listResult[i]].append(i) + assignee = listResult[i] + # Avoid bugs for maxClusterNumber + if assignee == args.maxClusterNumber: + assignee = args.maxClusterNumber-1 + clusterIndexList[assignee].append(i) reconNew = np.zeros( (scData.features.shape[0], scData.features.shape[1])) diff --git a/util_function.py b/util_function.py index 5da097c..65f02a6 100644 --- a/util_function.py +++ b/util_function.py @@ -576,7 +576,7 @@ def trimClustering(listResult, minMemberinCluster=5, maxClusterNumber=30): size = len(set(listResult)) changeDict = {} for item in range(size): - if numDict[item] < minMemberinCluster and item >= maxClusterNumber: + if numDict[item] < minMemberinCluster or item >= maxClusterNumber: changeDict[item] = '' count = 0