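Some refactors: move thread-count calculation into FastaMap/ProcessInfo, simplify CsvTable internals, and fix test imports and data paths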

This commit is contained in:
JasterV 2020-05-21 02:57:30 +02:00
parent 510cf1fc8b
commit 9830d66b09
7 changed files with 36 additions and 38 deletions

2
.gitignore vendored
View file

@ -2,7 +2,7 @@
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
.env
.vscode .vscode
# Distribution / packaging # Distribution / packaging

View file

@ -5,33 +5,23 @@ from sys import argv
from utils.csv_table import CsvTable from utils.csv_table import CsvTable
from utils.fasta_map import FastaMap from utils.fasta_map import FastaMap
from utils.process_info import ProcessInfo
signal.signal(signal.SIGTSTP, signal.SIG_IGN) signal.signal(signal.SIGTSTP, signal.SIG_IGN)
def calculate_max_threads(max_length, num_samples):
piu = ProcessInfo(num_samples, max_length)
return piu.max_threads if piu.max_threads <= 3 else 3
def main(): def main():
data_dir = argv[1] data_dir = argv[1]
csv_path = join(data_dir, "sequences.csv") csv_path = join(data_dir, "sequences.csv")
fasta_path = join(data_dir, "sequences.fasta") fasta_path = join(data_dir, "sequences.fasta")
print("Reading and processing files...") print("\nReading and processing files...")
csv_table = CsvTable(csv_path).group_countries_by_median_length() csv_table = CsvTable(csv_path).group_countries_by_median_length()
ids = csv_table.values('Accession') ids = csv_table.values('Accession')
fasta_map = FastaMap(fasta_path).filter(lambda item: item[0] in ids) fasta_map = FastaMap(fasta_path).filter(lambda item: item[0] in ids)
print("Files processing finished!") print("Files processing finished!")
max_length = max(map(int, csv_table.values("Length"))) print("\nBuilding hierarchy...")
num_samples = len(fasta_map) fasta_map.build_hierarchy()
max_threads_to_use = calculate_max_threads(max_length, num_samples)
print("Building hierarchy...")
fasta_map.build_hierarchy(max_threads_to_use)
print("Done!") print("Done!")

View file

@ -6,7 +6,7 @@ from typing import List, Union, Dict
from prettytable import PrettyTable from prettytable import PrettyTable
import utils.select import utils.select as sel
class CsvTable: class CsvTable:
@ -17,31 +17,27 @@ class CsvTable:
def __init__(self, arg: Union[str, List[Dict]]): def __init__(self, arg: Union[str, List[Dict]]):
if isinstance(arg, str): if isinstance(arg, str):
self.__table = self.__read(arg) self._table = self.__read(arg)
elif isinstance(arg, list): elif isinstance(arg, list):
self.__table = arg self._table = arg
else: else:
raise TypeError("Invalid Argument") raise TypeError("Invalid Argument")
def __getitem__(self, index): def __getitem__(self, index):
try: try:
return self.__table[index] return self._table[index]
except: except:
raise IndexError("Index out of range") raise IndexError("Index out of range")
def __len__(self): def __len__(self):
return len(self.__table) return len(self._table)
def __iter__(self): def __iter__(self):
for row in self.__table: for row in self._table:
yield row yield row
@property
def table(self):
return self.__table
def __str__(self): def __str__(self):
pretty_table = PrettyTable(list(self.__table[0].keys())) pretty_table = PrettyTable(list(self._table[0].keys()))
for row in self: for row in self:
pretty_table.add_row(row.values()) pretty_table.add_row(row.values())
return str(pretty_table) return str(pretty_table)
@ -68,7 +64,7 @@ class CsvTable:
return CsvTable(filtered_data) return CsvTable(filtered_data)
def __get_average_row(self, values: list) -> Union[dict, List[dict]]: def __get_average_row(self, values: list) -> Union[dict, List[dict]]:
median_value = utils.select.quick_select_median(values, index=1) median_value = sel.quick_select_median(values, index=1)
row = self[median_value[0]] row = self[median_value[0]]
geo_location = row['Geo_Location'] geo_location = row['Geo_Location']
row['Geo_Location'] = geo_location.split(":")[0] \ row['Geo_Location'] = geo_location.split(":")[0] \

View file

@ -7,6 +7,7 @@
import collections import collections
import time import time
from typing import Tuple, Dict, List, Callable, Union, Any from typing import Tuple, Dict, List, Callable, Union, Any
from utils.process_info import ProcessInfo
import libs.seqalign as sq import libs.seqalign as sq
@ -59,7 +60,13 @@ class FastaMap:
data[rna_id] = rna data[rna_id] = rna
return data return data
def _compare_all_samples(self, threads): def _compare_all_samples(self):
# Calculate the number of threads that can be
# used in order to speed up the comparisons
max_length = max(map(len, self.__data.values()))
num_samples = len(self.__data)
threads = ProcessInfo(num_samples, max_length).max_threads
# Start the comparisons
print("Performing comparisons...") print("Performing comparisons...")
start_time = time.time() start_time = time.time()
ids = list(self.__data.keys()) ids = list(self.__data.keys())
@ -67,16 +74,17 @@ class FastaMap:
for i in range(len(ids) - 1) for i in range(len(ids) - 1)
for j in range(i + 1, len(ids))] for j in range(i + 1, len(ids))]
comparisons = sq.par_compare(to_compare, self.__data, str(threads)) comparisons = sq.par_compare(to_compare, self.__data, str(threads))
print(f"Comparisons performed in {time.time() - start_time:.3f} seconds!") print(
f"Comparisons performed in {time.time() - start_time:.3f} seconds!")
return comparisons return comparisons
def build_hierarchy(self, threads) -> List[Union[Tuple[Any, ...], list]]: def build_hierarchy(self) -> List[Union[Tuple[Any, ...], list]]:
""" """
The function that is in charge of the comparison and the hierarchy of the samples The function that is in charge of the comparison and the hierarchy of the samples
:param threads_option: :param threads_option:
:return: :return:
""" """
comparisons = self._compare_all_samples(threads) comparisons = self._compare_all_samples()
table = self._to_dict(comparisons) table = self._to_dict(comparisons)
levels = [tuple(table.keys())] levels = [tuple(table.keys())]

View file

@ -12,7 +12,7 @@ import psutil
class ProcessInfo: class ProcessInfo:
""" """
informer and delimiter of system resources. Calculates system resources.
""" """
def __init__(self, num_samples, max_length): def __init__(self, num_samples, max_length):
@ -40,7 +40,7 @@ class ProcessInfo:
@property @property
def max_threads(self): def max_threads(self):
threads_available = self.num_logic_cores threads_available = self.num_logic_cores if self.num_logic_cores <= 3 else 3
threads = math.floor(self.mem_available / self.max_mem_per_comparison) threads = math.floor(self.mem_available / self.max_mem_per_comparison)
max_threads = threads if threads <= threads_available else threads_available max_threads = threads if threads <= threads_available else threads_available
return max_threads if max_threads >= 1 else 1 return max_threads if max_threads >= 1 else 1

View file

@ -1,11 +1,11 @@
from unittest import TestCase from unittest import TestCase
from src.python.utils.csv_table import CsvTable from utils.csv_table import CsvTable
class TestCsvTable(TestCase): class TestCsvTable(TestCase):
def test_values(self): def test_values(self):
path = "../data/data_test/sequences.csv" path = "data/data_test/sequences.csv"
accession = ["MT292569", "MT292570", "MT292571", "MT292572", "MT292574", accession = ["MT292569", "MT292570", "MT292571", "MT292572", "MT292574",
"MT292575", "MT292576", "MT292573", "MT292577", "MT292578", "MT292575", "MT292576", "MT292573", "MT292577", "MT292578",
"MT292579", "MT292580", "MT292581", "MT292582", "MT256917", "MT292579", "MT292580", "MT292581", "MT292582", "MT256917",
@ -17,7 +17,7 @@ class TestCsvTable(TestCase):
self.assertTrue(all([x in accession for x in csv_table.values("Accession")])) self.assertTrue(all([x in accession for x in csv_table.values("Accession")]))
def test_group_countries_by_median_length(self): def test_group_countries_by_median_length(self):
path = "../data/data_test/sequences.csv" path = "data/data_test/sequences.csv"
row_china = {"Accession": "MT259228", row_china = {"Accession": "MT259228",
"Release_Date": "2020-03-30T00:00:00Z", "Release_Date": "2020-03-30T00:00:00Z",
"Species": "Severe acute respiratory syndrome-related coronavirus", "Species": "Severe acute respiratory syndrome-related coronavirus",
@ -39,4 +39,5 @@ class TestCsvTable(TestCase):
list_cases_test = CsvTable(list_cases) list_cases_test = CsvTable(list_cases)
csv_table = CsvTable(path) csv_table = CsvTable(path)
list_csv = csv_table.group_countries_by_median_length() list_csv = csv_table.group_countries_by_median_length()
self.assertTrue(all(x in list_csv.table for x in list_cases_test.table)) self.assertTrue(all(x in list_csv._table for x in list_cases_test._table))

View file

@ -1,6 +1,6 @@
from unittest import TestCase from unittest import TestCase
import src.python.libs.seqalign as sq import libs.seqalign as sq
class TestFastaMap(TestCase): class TestFastaMap(TestCase):
@ -14,14 +14,17 @@ class TestFastaMap(TestCase):
seq2 = "AACG" seq2 = "AACG"
expected = 2 expected = 2
self.assertEqual(expected, sq.compare_samples(seq1, seq2)) self.assertEqual(expected, sq.compare_samples(seq1, seq2))
seq1 = "JDSLFA" seq1 = "JDSLFA"
seq2 = "JALFDSA" seq2 = "JALFDSA"
expected = 6 expected = 6
self.assertEqual(expected, sq.compare_samples(seq1, seq2)) self.assertEqual(expected, sq.compare_samples(seq1, seq2))
seq1 = "DSAJKJFHJAKHDIOUZVJCXMVCZIOIUOWUQRIEUWQIPIDSFSDKZXV" seq1 = "DSAJKJFHJAKHDIOUZVJCXMVCZIOIUOWUQRIEUWQIPIDSFSDKZXV"
seq2 = "FKSJAFKJIOTIGHLKJVMCXXZCMVLJASIRUWQOIUTPQWURIIPOQ" seq2 = "FKSJAFKJIOTIGHLKJVMCXXZCMVLJASIRUWQOIUTPQWURIIPOQ"
expected = 46 expected = 46
self.assertEqual(expected, sq.compare_samples(seq1, seq2)) self.assertEqual(expected, sq.compare_samples(seq1, seq2))
seq1 = "KFDSKAJFJPOTIPWQUIMXZMVMZXBVM" seq1 = "KFDSKAJFJPOTIPWQUIMXZMVMZXBVM"
seq2 = "FKDSPOIQTUITYSKZMV" seq2 = "FKDSPOIQTUITYSKZMV"
expected = 31 expected = 31