# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
from typing import List, Optional

from nlp_architect.common.cdc.mention_data import MentionData
class Cluster(object):
    """A set of mentions that share the same coref chain id.

    Attributes:
        mentions: member mentions, in insertion order.
        cluster_strings: ``tokens_str`` of each member, parallel to ``mentions``.
        merged: starts ``False``; this class never changes it (``Clusters.clean_clusters``
            drops clusters whose flag is ``True``).
        coref_chain: the cluster id; written to each member's ``predicted_coref_chain``.
        mentions_corefs: distinct ``coref_chain`` values of the member mentions.
    """

    def __init__(self, coref_chain: int = -1) -> None:
        """
        Args:
            coref_chain (int): the cluster id/coref_chain value
        """
        self.mentions: List["MentionData"] = []
        self.cluster_strings: List[str] = []
        self.merged = False
        self.coref_chain = coref_chain
        self.mentions_corefs = set()

    def get_mentions(self) -> List["MentionData"]:
        """Return the (live, not copied) list of member mentions."""
        return self.mentions

    def add_mention(self, mention: "MentionData") -> None:
        """Add ``mention`` to the cluster and label it with this cluster's id.

        ``None`` is silently ignored.

        Args:
            mention: mention to add; its ``predicted_coref_chain`` is overwritten
        """
        if mention is not None:
            mention.predicted_coref_chain = self.coref_chain
            self.mentions.append(mention)
            self.cluster_strings.append(mention.tokens_str)
            self.mentions_corefs.add(mention.coref_chain)

    def merge_clusters(self, cluster: "Cluster") -> None:
        """Absorb all mentions of ``cluster`` into this cluster.

        The absorbed mentions are relabelled with this cluster's ``coref_chain``.
        ``cluster`` keeps its own lists and its ``merged`` flag is not touched;
        marking it as merged is left to the caller.

        Args:
            cluster: cluster to merge this cluster with
        """
        if cluster is self:
            # Merging a cluster into itself would duplicate every mention/string.
            return
        for mention in cluster.mentions:
            mention.predicted_coref_chain = self.coref_chain
        self.mentions.extend(cluster.mentions)
        self.cluster_strings.extend(cluster.cluster_strings)
        self.mentions_corefs.update(cluster.mentions_corefs)

    def get_cluster_id(self) -> str:
        """
        Returns:
            A cluster id built by joining the members' ``mention_id`` values with
            ``"$"`` in insertion order (an empty string for an empty cluster).
        """
        # str() keeps the join working even if a mention id is not a string.
        return "$".join(str(mention.mention_id) for mention in self.mentions)
class Clusters(object):
    """A collection of ``Cluster`` objects that belong to one topic."""

    # Class-level counter shared by every Clusters instance: each singleton
    # cluster created by set_initial_clusters takes the current value and the
    # counter is incremented, so ids keep increasing across instances.
    cluster_coref_chain = 1000

    def __init__(self, topic_id: str, mentions: Optional[List["MentionData"]] = None) -> None:
        """
        Args:
            topic_id: identifier of the topic these clusters belong to
                (stored as ``self.topic_id``)
            mentions: optional initial mentions; each one is placed in its own
                singleton cluster. ``None`` (the default) starts empty.
        """
        self.clusters_list: List["Cluster"] = []
        self.topic_id = topic_id
        self.set_initial_clusters(mentions)

    def set_initial_clusters(self, mentions: Optional[List["MentionData"]]) -> None:
        """Append one singleton cluster per mention to ``clusters_list``.

        Each new cluster takes the next value of ``Clusters.cluster_coref_chain``.

        Args:
            mentions: the mentions to create the clusters from; ``None`` or an
                empty list adds nothing
        """
        if not mentions:
            return
        for mention in mentions:
            cluster = Cluster(Clusters.cluster_coref_chain)
            cluster.add_mention(mention)
            self.clusters_list.append(cluster)
            Clusters.cluster_coref_chain += 1

    def clean_clusters(self) -> None:
        """
        Remove all clusters that were already merged with other clusters
        (i.e. whose ``merged`` flag is truthy).
        """
        self.clusters_list = [cluster for cluster in self.clusters_list if not cluster.merged]

    def set_coref_chain_to_mentions(self) -> None:
        """
        Give all cluster mentions the same coref ID as their cluster's coref chain ID.
        """
        # NOTE(review): this writes str(coref_chain), whereas Cluster.add_mention /
        # merge_clusters write the raw value; kept as-is because callers may rely on it.
        for cluster in self.clusters_list:
            for mention in cluster.mentions:
                mention.predicted_coref_chain = str(cluster.coref_chain)

    def add_cluster(self, cluster: "Cluster") -> None:
        """Append a single cluster to ``clusters_list``."""
        self.clusters_list.append(cluster)

    def add_clusters(self, clusters: "Clusters") -> None:
        """Append every cluster of another ``Clusters`` (shared references, not copies).

        Args:
            clusters: collection whose ``clusters_list`` entries are appended
        """
        # extend() snapshots the source, so add_clusters(self) doubles the list
        # instead of looping forever as append-while-iterating did.
        self.clusters_list.extend(clusters.clusters_list)