From 9830d66b0980da1522f38c307dbc2b1f7a86a6bb Mon Sep 17 00:00:00 2001 From: JasterV Date: Thu, 21 May 2020 02:57:30 +0200 Subject: [PATCH] some refactors --- .gitignore | 2 +- src/python/sarscovhierarchy.py | 16 +++------------- src/python/utils/csv_table.py | 20 ++++++++------------ src/python/utils/fasta_map.py | 16 ++++++++++++---- src/python/utils/process_info.py | 6 +++--- test/test_csv_table.py | 9 +++++---- test/test_fasta_map.py | 5 ++++- 7 files changed, 36 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index c4b61cf..5b8f7d0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class - +.env .vscode # Distribution / packaging diff --git a/src/python/sarscovhierarchy.py b/src/python/sarscovhierarchy.py index 690de88..246fa4f 100644 --- a/src/python/sarscovhierarchy.py +++ b/src/python/sarscovhierarchy.py @@ -5,33 +5,23 @@ from sys import argv from utils.csv_table import CsvTable from utils.fasta_map import FastaMap -from utils.process_info import ProcessInfo signal.signal(signal.SIGTSTP, signal.SIG_IGN) -def calculate_max_threads(max_length, num_samples): - piu = ProcessInfo(num_samples, max_length) - return piu.max_threads if piu.max_threads <= 3 else 3 - - def main(): data_dir = argv[1] csv_path = join(data_dir, "sequences.csv") fasta_path = join(data_dir, "sequences.fasta") - print("Reading and processing files...") + print("\nReading and processing files...") csv_table = CsvTable(csv_path).group_countries_by_median_length() ids = csv_table.values('Accession') fasta_map = FastaMap(fasta_path).filter(lambda item: item[0] in ids) print("Files processing finished!") - max_length = max(map(int, csv_table.values("Length"))) - num_samples = len(fasta_map) - max_threads_to_use = calculate_max_threads(max_length, num_samples) - - print("Building hierarchy...") - fasta_map.build_hierarchy(max_threads_to_use) + print("\nBuilding hierarchy...") + fasta_map.build_hierarchy() print("Done!") diff --git a/src/python/utils/csv_table.py b/src/python/utils/csv_table.py index 5c47d13..fa0a3de 100644 --- a/src/python/utils/csv_table.py +++ b/src/python/utils/csv_table.py @@ -6,7 +6,7 @@ from typing import List, Union, Dict from prettytable import PrettyTable -import utils.select +import utils.select as sel class CsvTable: @@ -17,31 +17,27 @@ class CsvTable: def __init__(self, arg: Union[str, List[Dict]]): if isinstance(arg, str): - self.__table = self.__read(arg) + self._table = self.__read(arg) elif isinstance(arg, list): - self.__table = arg + self._table = arg else: raise TypeError("Invalid Argument") def __getitem__(self, index): try: - return self.__table[index] + return self._table[index] except: raise IndexError("Index out of range") def __len__(self): - return len(self.__table) + return len(self._table) def __iter__(self): - for row in self.__table: + for row in self._table: yield row - @property - def table(self): - return self.__table - def __str__(self): - pretty_table = PrettyTable(list(self.__table[0].keys())) + pretty_table = PrettyTable(list(self._table[0].keys())) for row in self: pretty_table.add_row(row.values()) return str(pretty_table) @@ -68,7 +64,7 @@ class CsvTable: return CsvTable(filtered_data) def __get_average_row(self, values: list) -> Union[dict, List[dict]]: - median_value = utils.select.quick_select_median(values, index=1) + median_value = sel.quick_select_median(values, index=1) row = self[median_value[0]] geo_location = row['Geo_Location'] row['Geo_Location'] = geo_location.split(":")[0] \ diff --git a/src/python/utils/fasta_map.py b/src/python/utils/fasta_map.py index be6bb82..9f85e64 100644 --- a/src/python/utils/fasta_map.py +++ b/src/python/utils/fasta_map.py @@ -7,6 +7,7 @@ import collections import time from typing import Tuple, Dict, List, Callable, Union, Any +from utils.process_info import ProcessInfo import libs.seqalign as sq @@ -59,7 +60,13 @@ class FastaMap: data[rna_id] = rna return data - def _compare_all_samples(self, threads): + def _compare_all_samples(self): + # Calculate the number of threads that can be + # used in order to speed up the comparisons + max_length = max(map(len, self.__data.values())) + num_samples = len(self.__data) + threads = ProcessInfo(num_samples, max_length).max_threads + # Start the comparisons print("Performing comparisons...") start_time = time.time() ids = list(self.__data.keys()) @@ -67,16 +74,17 @@ class FastaMap: for i in range(len(ids) - 1) for j in range(i + 1, len(ids))] comparisons = sq.par_compare(to_compare, self.__data, str(threads)) - print(f"Comparisons performed in {time.time() - start_time:.3f} seconds!") + print( + f"Comparisons performed in {time.time() - start_time:.3f} seconds!") return comparisons - def build_hierarchy(self, threads) -> List[Union[Tuple[Any, ...], list]]: + def build_hierarchy(self) -> List[Union[Tuple[Any, ...], list]]: """ The function that is in charge of the comparison and the hierarchy of the samples :param threads_option: :return: """ - comparisons = self._compare_all_samples(threads) + comparisons = self._compare_all_samples() table = self._to_dict(comparisons) levels = [tuple(table.keys())] diff --git a/src/python/utils/process_info.py b/src/python/utils/process_info.py index 075aedf..ef0b810 100644 --- a/src/python/utils/process_info.py +++ b/src/python/utils/process_info.py @@ -12,7 +12,7 @@ import psutil class ProcessInfo: """ - informer and delimiter of system resources. + Calculates system resources. """ def __init__(self, num_samples, max_length): @@ -40,7 +40,7 @@ class ProcessInfo: @property def max_threads(self): - threads_available = self.num_logic_cores + threads_available = self.num_logic_cores if self.num_logic_cores <= 3 else 3 threads = math.floor(self.mem_available / self.max_mem_per_comparison) max_threads = threads if threads <= threads_available else threads_available - return max_threads if max_threads >= 1 else 1 \ No newline at end of file + return max_threads if max_threads >= 1 else 1 diff --git a/test/test_csv_table.py b/test/test_csv_table.py index b2486f6..963097b 100644 --- a/test/test_csv_table.py +++ b/test/test_csv_table.py @@ -1,11 +1,11 @@ from unittest import TestCase -from src.python.utils.csv_table import CsvTable +from utils.csv_table import CsvTable class TestCsvTable(TestCase): def test_values(self): - path = "../data/data_test/sequences.csv" + path = "data/data_test/sequences.csv" accession = ["MT292569", "MT292570", "MT292571", "MT292572", "MT292574", "MT292575", "MT292576", "MT292573", "MT292577", "MT292578", "MT292579", "MT292580", "MT292581", "MT292582", "MT256917", @@ -17,7 +17,7 @@ class TestCsvTable(TestCase): self.assertTrue(all([x in accession for x in csv_table.values("Accession")])) def test_group_countries_by_median_length(self): - path = "../data/data_test/sequences.csv" + path = "data/data_test/sequences.csv" row_china = {"Accession": "MT259228", "Release_Date": "2020-03-30T00:00:00Z", "Species": "Severe acute respiratory syndrome-related coronavirus", @@ -39,4 +39,5 @@ class TestCsvTable(TestCase): list_cases_test = CsvTable(list_cases) csv_table = CsvTable(path) list_csv = csv_table.group_countries_by_median_length() - self.assertTrue(all(x in list_csv.table for x in list_cases_test.table)) + self.assertTrue(all(x in list_csv._table for x in list_cases_test._table)) + diff --git a/test/test_fasta_map.py b/test/test_fasta_map.py index a7ebcc7..70116dc 100644 --- a/test/test_fasta_map.py +++ b/test/test_fasta_map.py @@ -1,6 +1,6 @@ from unittest import TestCase -import src.python.libs.seqalign as sq +import libs.seqalign as sq class TestFastaMap(TestCase): @@ -14,14 +14,17 @@ class TestFastaMap(TestCase): seq2 = "AACG" expected = 2 self.assertEqual(expected, sq.compare_samples(seq1, seq2)) + seq1 = "JDSLFA" seq2 = "JALFDSA" expected = 6 self.assertEqual(expected, sq.compare_samples(seq1, seq2)) + seq1 = "DSAJKJFHJAKHDIOUZVJCXMVCZIOIUOWUQRIEUWQIPIDSFSDKZXV" seq2 = "FKSJAFKJIOTIGHLKJVMCXXZCMVLJASIRUWQOIUTPQWURIIPOQ" expected = 46 self.assertEqual(expected, sq.compare_samples(seq1, seq2)) + seq1 = "KFDSKAJFJPOTIPWQUIMXZMVMZXBVM" seq2 = "FKDSPOIQTUITYSKZMV" expected = 31