mirror of
https://codeberg.org/JasterV/sarscov-hierarchy.git
synced 2026-04-26 18:10:08 +00:00
some refactors
This commit is contained in:
parent
510cf1fc8b
commit
9830d66b09
7 changed files with 36 additions and 38 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -2,7 +2,7 @@
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
.env
|
||||||
.vscode
|
.vscode
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
|
|
|
||||||
|
|
@ -5,33 +5,23 @@ from sys import argv
|
||||||
|
|
||||||
from utils.csv_table import CsvTable
|
from utils.csv_table import CsvTable
|
||||||
from utils.fasta_map import FastaMap
|
from utils.fasta_map import FastaMap
|
||||||
from utils.process_info import ProcessInfo
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTSTP, signal.SIG_IGN)
|
signal.signal(signal.SIGTSTP, signal.SIG_IGN)
|
||||||
|
|
||||||
|
|
||||||
def calculate_max_threads(max_length, num_samples):
|
|
||||||
piu = ProcessInfo(num_samples, max_length)
|
|
||||||
return piu.max_threads if piu.max_threads <= 3 else 3
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
data_dir = argv[1]
|
data_dir = argv[1]
|
||||||
csv_path = join(data_dir, "sequences.csv")
|
csv_path = join(data_dir, "sequences.csv")
|
||||||
fasta_path = join(data_dir, "sequences.fasta")
|
fasta_path = join(data_dir, "sequences.fasta")
|
||||||
|
|
||||||
print("Reading and processing files...")
|
print("\nReading and processing files...")
|
||||||
csv_table = CsvTable(csv_path).group_countries_by_median_length()
|
csv_table = CsvTable(csv_path).group_countries_by_median_length()
|
||||||
ids = csv_table.values('Accession')
|
ids = csv_table.values('Accession')
|
||||||
fasta_map = FastaMap(fasta_path).filter(lambda item: item[0] in ids)
|
fasta_map = FastaMap(fasta_path).filter(lambda item: item[0] in ids)
|
||||||
print("Files processing finished!")
|
print("Files processing finished!")
|
||||||
|
|
||||||
max_length = max(map(int, csv_table.values("Length")))
|
print("\nBuilding hierarchy...")
|
||||||
num_samples = len(fasta_map)
|
fasta_map.build_hierarchy()
|
||||||
max_threads_to_use = calculate_max_threads(max_length, num_samples)
|
|
||||||
|
|
||||||
print("Building hierarchy...")
|
|
||||||
fasta_map.build_hierarchy(max_threads_to_use)
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import List, Union, Dict
|
||||||
|
|
||||||
from prettytable import PrettyTable
|
from prettytable import PrettyTable
|
||||||
|
|
||||||
import utils.select
|
import utils.select as sel
|
||||||
|
|
||||||
|
|
||||||
class CsvTable:
|
class CsvTable:
|
||||||
|
|
@ -17,31 +17,27 @@ class CsvTable:
|
||||||
|
|
||||||
def __init__(self, arg: Union[str, List[Dict]]):
|
def __init__(self, arg: Union[str, List[Dict]]):
|
||||||
if isinstance(arg, str):
|
if isinstance(arg, str):
|
||||||
self.__table = self.__read(arg)
|
self._table = self.__read(arg)
|
||||||
elif isinstance(arg, list):
|
elif isinstance(arg, list):
|
||||||
self.__table = arg
|
self._table = arg
|
||||||
else:
|
else:
|
||||||
raise TypeError("Invalid Argument")
|
raise TypeError("Invalid Argument")
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
try:
|
try:
|
||||||
return self.__table[index]
|
return self._table[index]
|
||||||
except:
|
except:
|
||||||
raise IndexError("Index out of range")
|
raise IndexError("Index out of range")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.__table)
|
return len(self._table)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
for row in self.__table:
|
for row in self._table:
|
||||||
yield row
|
yield row
|
||||||
|
|
||||||
@property
|
|
||||||
def table(self):
|
|
||||||
return self.__table
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
pretty_table = PrettyTable(list(self.__table[0].keys()))
|
pretty_table = PrettyTable(list(self._table[0].keys()))
|
||||||
for row in self:
|
for row in self:
|
||||||
pretty_table.add_row(row.values())
|
pretty_table.add_row(row.values())
|
||||||
return str(pretty_table)
|
return str(pretty_table)
|
||||||
|
|
@ -68,7 +64,7 @@ class CsvTable:
|
||||||
return CsvTable(filtered_data)
|
return CsvTable(filtered_data)
|
||||||
|
|
||||||
def __get_average_row(self, values: list) -> Union[dict, List[dict]]:
|
def __get_average_row(self, values: list) -> Union[dict, List[dict]]:
|
||||||
median_value = utils.select.quick_select_median(values, index=1)
|
median_value = sel.quick_select_median(values, index=1)
|
||||||
row = self[median_value[0]]
|
row = self[median_value[0]]
|
||||||
geo_location = row['Geo_Location']
|
geo_location = row['Geo_Location']
|
||||||
row['Geo_Location'] = geo_location.split(":")[0] \
|
row['Geo_Location'] = geo_location.split(":")[0] \
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
import collections
|
import collections
|
||||||
import time
|
import time
|
||||||
from typing import Tuple, Dict, List, Callable, Union, Any
|
from typing import Tuple, Dict, List, Callable, Union, Any
|
||||||
|
from utils.process_info import ProcessInfo
|
||||||
|
|
||||||
import libs.seqalign as sq
|
import libs.seqalign as sq
|
||||||
|
|
||||||
|
|
@ -59,7 +60,13 @@ class FastaMap:
|
||||||
data[rna_id] = rna
|
data[rna_id] = rna
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _compare_all_samples(self, threads):
|
def _compare_all_samples(self):
|
||||||
|
# Calculate the number of threads that can be
|
||||||
|
# used in order to speed up the comparisons
|
||||||
|
max_length = max(map(len, self.__data.values()))
|
||||||
|
num_samples = len(self.__data)
|
||||||
|
threads = ProcessInfo(num_samples, max_length).max_threads
|
||||||
|
# Start the comparisons
|
||||||
print("Performing comparisons...")
|
print("Performing comparisons...")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
ids = list(self.__data.keys())
|
ids = list(self.__data.keys())
|
||||||
|
|
@ -67,16 +74,17 @@ class FastaMap:
|
||||||
for i in range(len(ids) - 1)
|
for i in range(len(ids) - 1)
|
||||||
for j in range(i + 1, len(ids))]
|
for j in range(i + 1, len(ids))]
|
||||||
comparisons = sq.par_compare(to_compare, self.__data, str(threads))
|
comparisons = sq.par_compare(to_compare, self.__data, str(threads))
|
||||||
print(f"Comparisons performed in {time.time() - start_time:.3f} seconds!")
|
print(
|
||||||
|
f"Comparisons performed in {time.time() - start_time:.3f} seconds!")
|
||||||
return comparisons
|
return comparisons
|
||||||
|
|
||||||
def build_hierarchy(self, threads) -> List[Union[Tuple[Any, ...], list]]:
|
def build_hierarchy(self) -> List[Union[Tuple[Any, ...], list]]:
|
||||||
"""
|
"""
|
||||||
The function that is in charge of the comparison and the hierarchy of the samples
|
The function that is in charge of the comparison and the hierarchy of the samples
|
||||||
:param threads_option:
|
:param threads_option:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
comparisons = self._compare_all_samples(threads)
|
comparisons = self._compare_all_samples()
|
||||||
table = self._to_dict(comparisons)
|
table = self._to_dict(comparisons)
|
||||||
levels = [tuple(table.keys())]
|
levels = [tuple(table.keys())]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import psutil
|
||||||
|
|
||||||
class ProcessInfo:
|
class ProcessInfo:
|
||||||
"""
|
"""
|
||||||
informer and delimiter of system resources.
|
Calculates system resources.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_samples, max_length):
|
def __init__(self, num_samples, max_length):
|
||||||
|
|
@ -40,7 +40,7 @@ class ProcessInfo:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_threads(self):
|
def max_threads(self):
|
||||||
threads_available = self.num_logic_cores
|
threads_available = self.num_logic_cores if self.num_logic_cores <= 3 else 3
|
||||||
threads = math.floor(self.mem_available / self.max_mem_per_comparison)
|
threads = math.floor(self.mem_available / self.max_mem_per_comparison)
|
||||||
max_threads = threads if threads <= threads_available else threads_available
|
max_threads = threads if threads <= threads_available else threads_available
|
||||||
return max_threads if max_threads >= 1 else 1
|
return max_threads if max_threads >= 1 else 1
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from src.python.utils.csv_table import CsvTable
|
from utils.csv_table import CsvTable
|
||||||
|
|
||||||
|
|
||||||
class TestCsvTable(TestCase):
|
class TestCsvTable(TestCase):
|
||||||
def test_values(self):
|
def test_values(self):
|
||||||
path = "../data/data_test/sequences.csv"
|
path = "data/data_test/sequences.csv"
|
||||||
accession = ["MT292569", "MT292570", "MT292571", "MT292572", "MT292574",
|
accession = ["MT292569", "MT292570", "MT292571", "MT292572", "MT292574",
|
||||||
"MT292575", "MT292576", "MT292573", "MT292577", "MT292578",
|
"MT292575", "MT292576", "MT292573", "MT292577", "MT292578",
|
||||||
"MT292579", "MT292580", "MT292581", "MT292582", "MT256917",
|
"MT292579", "MT292580", "MT292581", "MT292582", "MT256917",
|
||||||
|
|
@ -17,7 +17,7 @@ class TestCsvTable(TestCase):
|
||||||
self.assertTrue(all([x in accession for x in csv_table.values("Accession")]))
|
self.assertTrue(all([x in accession for x in csv_table.values("Accession")]))
|
||||||
|
|
||||||
def test_group_countries_by_median_length(self):
|
def test_group_countries_by_median_length(self):
|
||||||
path = "../data/data_test/sequences.csv"
|
path = "data/data_test/sequences.csv"
|
||||||
row_china = {"Accession": "MT259228",
|
row_china = {"Accession": "MT259228",
|
||||||
"Release_Date": "2020-03-30T00:00:00Z",
|
"Release_Date": "2020-03-30T00:00:00Z",
|
||||||
"Species": "Severe acute respiratory syndrome-related coronavirus",
|
"Species": "Severe acute respiratory syndrome-related coronavirus",
|
||||||
|
|
@ -39,4 +39,5 @@ class TestCsvTable(TestCase):
|
||||||
list_cases_test = CsvTable(list_cases)
|
list_cases_test = CsvTable(list_cases)
|
||||||
csv_table = CsvTable(path)
|
csv_table = CsvTable(path)
|
||||||
list_csv = csv_table.group_countries_by_median_length()
|
list_csv = csv_table.group_countries_by_median_length()
|
||||||
self.assertTrue(all(x in list_csv.table for x in list_cases_test.table))
|
self.assertTrue(all(x in list_csv._table for x in list_cases_test._table))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import src.python.libs.seqalign as sq
|
import libs.seqalign as sq
|
||||||
|
|
||||||
|
|
||||||
class TestFastaMap(TestCase):
|
class TestFastaMap(TestCase):
|
||||||
|
|
@ -14,14 +14,17 @@ class TestFastaMap(TestCase):
|
||||||
seq2 = "AACG"
|
seq2 = "AACG"
|
||||||
expected = 2
|
expected = 2
|
||||||
self.assertEqual(expected, sq.compare_samples(seq1, seq2))
|
self.assertEqual(expected, sq.compare_samples(seq1, seq2))
|
||||||
|
|
||||||
seq1 = "JDSLFA"
|
seq1 = "JDSLFA"
|
||||||
seq2 = "JALFDSA"
|
seq2 = "JALFDSA"
|
||||||
expected = 6
|
expected = 6
|
||||||
self.assertEqual(expected, sq.compare_samples(seq1, seq2))
|
self.assertEqual(expected, sq.compare_samples(seq1, seq2))
|
||||||
|
|
||||||
seq1 = "DSAJKJFHJAKHDIOUZVJCXMVCZIOIUOWUQRIEUWQIPIDSFSDKZXV"
|
seq1 = "DSAJKJFHJAKHDIOUZVJCXMVCZIOIUOWUQRIEUWQIPIDSFSDKZXV"
|
||||||
seq2 = "FKSJAFKJIOTIGHLKJVMCXXZCMVLJASIRUWQOIUTPQWURIIPOQ"
|
seq2 = "FKSJAFKJIOTIGHLKJVMCXXZCMVLJASIRUWQOIUTPQWURIIPOQ"
|
||||||
expected = 46
|
expected = 46
|
||||||
self.assertEqual(expected, sq.compare_samples(seq1, seq2))
|
self.assertEqual(expected, sq.compare_samples(seq1, seq2))
|
||||||
|
|
||||||
seq1 = "KFDSKAJFJPOTIPWQUIMXZMVMZXBVM"
|
seq1 = "KFDSKAJFJPOTIPWQUIMXZMVMZXBVM"
|
||||||
seq2 = "FKDSPOIQTUITYSKZMV"
|
seq2 = "FKDSPOIQTUITYSKZMV"
|
||||||
expected = 31
|
expected = 31
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue