Source code for nlp_architect.models.cross_doc_coref.system.sieves.run_sieve_system

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

import logging
import time

from nlp_architect.common.cdc.cluster import Clusters
from nlp_architect.common.cdc.topics import Topic
from nlp_architect.models.cross_doc_coref.system.sieves.sieves import SieveClusterMerger
from nlp_architect.models.cross_doc_coref.system.sieves_container_init import (
    SievesContainerInitialization,
)

logger = logging.getLogger(__name__)


class RunSystemsSuper(object):
    """Base runner that applies a sequence of sieves to the clusters of one topic.

    Subclasses are expected to populate ``self.sieves``; this base class
    leaves the list empty.
    """

    def __init__(self, topic: Topic):
        """Create the initial clusters for ``topic``.

        Args:
            topic: object exposing ``topic_id`` and ``mentions``; both are
                handed to ``Clusters``.
        """
        self.sieves = []
        self.results_dict = {}
        self.results_ordered = []
        logger.info(
            "loading topic %s, total mentions: %d",
            topic.topic_id,
            len(topic.mentions),
        )
        self.clusters = Clusters(topic.topic_id, topic.mentions)

    @staticmethod
    def set_sieves_from_config(config, get_rel_extraction):
        """Build one ``SieveClusterMerger`` per entry of ``config.sieves_order``.

        Args:
            config: object exposing ``sieves_order``, an iterable of
                indexable entries.
            get_rel_extraction: callable invoked with the first element of
                each entry; its result is passed to the merger as the second
                constructor argument.

        Returns:
            list of ``SieveClusterMerger`` objects, in configuration order.
        """
        return [
            SieveClusterMerger(sieve_spec, get_rel_extraction(sieve_spec[0]))
            for sieve_spec in config.sieves_order
        ]

    def _merge_pass(self, sieve):
        """Run one pairwise sweep over the current clusters with ``sieve``.

        Every not-yet-merged pair ``(i, j)`` with ``i < j`` is offered to
        ``sieve.run_sieve``; on a truthy answer cluster ``j`` is merged into
        cluster ``i`` and flagged as merged.

        Returns:
            int: number of merges performed during this sweep.
        """
        merges = 0
        # The size is sampled once per sweep; clusters are addressed by index
        # on every access, exactly as the list is at that moment.
        total = len(self.clusters.clusters_list)
        for left_idx in range(total):
            left = self.clusters.clusters_list[left_idx]
            if left.merged:
                continue
            for right_idx in range(left_idx + 1, total):
                right = self.clusters.clusters_list[right_idx]
                if right.merged or left is right:
                    continue
                if sieve.run_sieve(left, right):
                    merges += 1
                    left.merge_clusters(right)
                    right.merged = True
        return merges

    def run_deterministic(self):
        """Apply every sieve in order, each one until it produces no more merges.

        After each sweep that merged at least one pair,
        ``self.clusters.clean_clusters()`` is called (presumably to drop the
        clusters flagged as merged -- confirm in ``Clusters``) and the sweep
        is repeated. The merge count and elapsed time are logged per sieve.

        Returns:
            The ``Clusters`` object held by this runner, after all sieves ran.
        """
        for sieve in self.sieves:
            started_at = time.time()
            merge_count = 0
            while True:
                merged_now = self._merge_pass(sieve)
                if not merged_now:
                    break
                merge_count += merged_now
                self.clusters.clean_clusters()
            took = time.time() - started_at
            logger.info(
                "Total of %d clusters merged using method: %s, took: %.4f sec",
                merge_count,
                str(sieve.excepted_relation),
                took,
            )

        return self.clusters

    def get_results(self):
        """Return ``self.results_ordered`` (initialised as an empty list)."""
        return self.results_ordered
class RunSystemsEntity(RunSystemsSuper):
    """Sieve runner configured from the entity section of the resources."""

    def __init__(self, topic: Topic, resources):
        """Initialise the clusters for ``topic`` and load the entity sieves.

        Args:
            topic: forwarded unchanged to ``RunSystemsSuper.__init__``.
            resources: object exposing ``entity_config`` and
                ``get_module_from_relation``.
        """
        super().__init__(topic)
        sieve_config = resources.entity_config
        relation_lookup = resources.get_module_from_relation
        self.sieves = self.set_sieves_from_config(sieve_config, relation_lookup)
class RunSystemsEvent(RunSystemsSuper):
    """Sieve runner configured from the event section of the resources."""

    def __init__(self, topic, resources):
        """Initialise the clusters for ``topic`` and load the event sieves.

        Args:
            topic: forwarded unchanged to ``RunSystemsSuper.__init__``.
            resources: object exposing ``event_config`` and
                ``get_module_from_relation``.
        """
        super().__init__(topic)
        sieve_config = resources.event_config
        relation_lookup = resources.get_module_from_relation
        self.sieves = self.set_sieves_from_config(sieve_config, relation_lookup)
# pylint: disable=no-else-return
def get_run_system(topic: Topic, resource: SievesContainerInitialization, eval_type: str):
    """Create the sieve runner that matches ``eval_type``.

    Args:
        topic: topic passed to the runner's constructor.
        resource: resources container passed to the runner's constructor.
        eval_type: ``"entity"`` or ``"event"``, compared case-insensitively.

    Returns:
        A ``RunSystemsEntity`` or ``RunSystemsEvent`` instance.

    Raises:
        AttributeError: if ``eval_type`` is neither of the supported values.
    """
    runners_by_type = {
        "entity": RunSystemsEntity,
        "event": RunSystemsEvent,
    }
    runner_cls = runners_by_type.get(eval_type.lower())
    if runner_cls is None:
        raise AttributeError(eval_type + " Not supported!")
    return runner_cls(topic, resource)