hisat-3n/evaluation/simulation/init.py

344 lines
12 KiB
Python
Raw Permalink Normal View History

2025-01-18 13:09:52 +00:00
#!/usr/bin/env python
import sys, os
import string, re
use_message = '''
'''
def extract_pair(RNA):
read_dic = {}
pair_reported = set()
out_file = open("sim_paired.sam", "w")
hits_file = open("sim.sam")
for line in hits_file:
if line.startswith('@'):
continue
fields = line[:-1].split()
read_id, flag, chr1, pos1, mapQ, cigar1, chr2, pos2 = fields[:8]
flag = int(flag)
assert flag & 0x4 == 0
pos1, pos2 = int(pos1), int(pos2)
TI, NM1, Zs1 = "", "", ""
for field in fields[11:]:
if field.startswith("TI"):
TI = "\t" + field
elif field.startswith("NM"):
NM1 = "\t" + field
elif field.startswith("Zs"):
Zs1 = "\t" + field
assert NM1 != ""
assert chr1 == chr2
me = "%s\t%s\t%d" % (read_id, chr1, pos1)
partner = "%s\t%s\t%d" % (read_id, chr2, pos2)
if partner in read_dic:
maps = read_dic[partner]
for map in maps:
if map[0] == me:
cigar2, NM2, Zs2 = map[1:4]
if int(pos2) > int(pos1):
p_str = "%s\t%s\t%d\t%s\t%s\t%d\t%s%s%s%s%s%s" % \
(read_id, chr1, pos1, cigar1, chr2, pos2, cigar2, TI, NM1, NM2, Zs1, Zs2)
else:
p_str = "%s\t%s\t%d\t%s\t%s\t%d\t%s%s%s%s%s%s" % \
(read_id, chr2, pos2, cigar2, chr1, pos1, cigar1, TI, NM2, NM1, Zs2, Zs1)
if p_str not in pair_reported:
pair_reported.add(p_str)
print >> out_file, p_str
if not me in read_dic:
read_dic[me] = []
read_dic[me].append([partner, cigar1, NM1, Zs1])
hits_file.close()
out_file.close()
def init_reads(read_dir, RNA):
sim1_file = open("sim_1.sam", 'w')
sim2_file = open("sim_2.sam", 'w')
for line in open("sim.sam"):
if line.startswith('@'):
continue
fields = line[:-1].split('\t')
read_id, flag, chr, pos, _, cigar = fields[:6]
left_read = int(flag) < 128
NM, TI, Zs = "", "", ""
for field in fields:
if field.startswith("NM"):
NM = "\t" + field
elif field.startswith("TI"):
TI = "\t" + field
elif field.startswith("Zs"):
Zs = "\t" + field
print >> (sim1_file if left_read else sim2_file), \
"%s\t%s\t%s\t%s%s%s%s" % (read_id, chr, pos, cigar, TI, NM, Zs)
sim1_file.close()
sim2_file.close()
extract_pair(RNA)
def to_junction_str(junction):
return "%s-%d-%d" % (junction[0], junction[1], junction[2])
def to_junction(junction_str):
fields = junction_str.split("-")
if len(fields) > 3:
chr, left, right = "-".join(fields[:-2]), fields[-2], fields[-1]
else:
assert len(fields) == 3
chr, left, right = fields
return [chr, int(left), int(right)]
def junction_cmp(a, b):
if a[0] != b[0]:
if a[0] < b[0]:
return -1
else:
return 1
if a[1] != b[1]:
if a[1] < b[1]:
return -1
else:
return 1
if a[2] != b[2]:
if a[2] < b[2]:
return -1
else:
return 1
return 0
def classify_reads(RNA):
if RNA:
readtypes = ["all", "M", "2M_gt_15", "2M_8_15", "2M_1_7", "gt_2M"]
else:
readtypes = ["all" ,"M"]
readtype_order = {}
for i in range(len(readtypes)):
readtype = readtypes[i]
assert readtype not in readtype_order
readtype_order[readtype] = i
for paired in [False, True]:
for readtype in readtypes:
if paired:
base_fname = "sim_paired"
type_sam_fname = base_fname + "_" + readtype + ".sam"
type_read1_fname = base_fname + "_1_" + readtype + ".fa"
type_read2_fname = base_fname + "_2_" + readtype + ".fa"
type_junction_fname = base_fname + "_" + readtype + ".junc"
else:
base_fname = "sim_single"
type_sam_fname = base_fname + "_" + readtype + ".sam"
type_read1_fname = base_fname + "_" + readtype + ".fa"
type_read2_fname = ""
type_junction_fname = base_fname + "_" + readtype + ".junc"
if os.path.exists(type_sam_fname) and \
os.path.exists(type_junction_fname):
continue
read_ids = set()
junctions = []
type_sam_file = open(type_sam_fname, "w")
if paired:
sam_file = open("sim_paired.sam")
else:
sam_file = open("sim_1.sam")
cigar_re = re.compile('\d+\w')
def get_read_type(cigar):
cigars = cigar_re.findall(cigar)
def get_cigar_chars_MN(cigars):
cigars = cigar_re.findall(cigars)
cigar_chars = ""
for cigar in cigars:
cigar_op = cigar[-1]
if cigar_op in "MN":
if cigar_chars == "" or cigar_chars[-1] != cigar_op:
cigar_chars += cigar_op
return cigar_chars
cigar_str = get_cigar_chars_MN(cigar)
if cigar_str == "M":
return cigar_str
elif cigar_str == "MNM":
assert len(cigars) >= 3
left_anchor = 0
for ci in range(len(cigars)):
c = cigars[ci][-1]
c_len = int(cigars[ci][:-1])
if c in "MI":
left_anchor += c_len
else:
break
assert left_anchor > 0
right_anchor = 0
for ci in reversed(range(len(cigars))):
c = cigars[ci][-1]
c_len = int(cigars[ci][:-1])
if c in "MI":
right_anchor += c_len
else:
break
assert right_anchor > 0
min_anchor = min(left_anchor, right_anchor)
if min_anchor > 15:
return "2M_gt_15"
elif min_anchor >= 8 and min_anchor <= 15:
return "2M_8_15"
elif min_anchor >= 1 and min_anchor <= 7:
return "2M_1_7"
else:
assert False
else:
assert cigar_str not in ["M", "MNM"]
return "gt_2M"
# chr and pos are assumed to be integers
def get_junctions(chr, pos, cigar_str):
junctions = []
right_pos = pos
cigars = cigar_re.findall(cigar_str)
cigars = [[int(cigars[i][:-1]), cigars[i][-1]] for i in range(len(cigars))]
for i in range(len(cigars)):
length, cigar_op = cigars[i]
if cigar_op == "N":
left, right = right_pos - 1, right_pos + length
if i > 0 and cigars[i-1][1] in "ID":
if cigars[i-1][1] == "I":
left += cigars[i-1][0]
else:
left -= cigars[i-1][0]
if i + 1 < len(cigars) and cigars[i+1][1] in "ID":
if cigars[i+1][1] == "I":
right -= cigars[i+1][0]
else:
right += cigars[i+1][0]
junctions.append([chr, left, right])
if cigar_op in "MND":
right_pos += length
return junctions
for line in sam_file:
fields = line[:-1].split()
if paired:
read_id, chr, pos, cigar, chr2, pos2, cigar2 = fields[:7]
else:
read_id, chr, pos, cigar = fields[:4]
read_id = int(read_id)
readtype2 = get_read_type(cigar)
if paired:
readtype3 = get_read_type(cigar2)
assert readtype2 in readtype_order
assert readtype3 in readtype_order
if readtype_order[readtype2] < readtype_order[readtype3]:
readtype2 = readtype3
if readtype == "all" or readtype == readtype2:
read_ids.add(read_id)
print >> type_sam_file, line[:-1]
junctions += get_junctions(chr, int(pos), cigar)
if paired:
junctions += get_junctions(chr2, int(pos2), cigar2)
sam_file.close()
type_sam_file.close()
# make this set non-redundant
junction_set = set()
for junction in junctions:
junction_set.add(to_junction_str(junction))
junctions = []
for junction_str in junction_set:
junctions.append(to_junction(junction_str))
# sort the list of junctions
junctions = sorted(junctions, cmp=junction_cmp)
# write the junctions into a file
type_junction_file = open(type_junction_fname, "w")
for junction in junctions:
print >> type_junction_file, "%s\t%d\t%d" % (junction[0], junction[1], junction[2])
type_junction_file.close()
def write_reads(read_fname, type_read_fname):
type_read_file = open(type_read_fname, "w")
read_file = open(read_fname)
write = False
for line in read_file:
if line[0] == ">":
read_id = int(line[1:-1])
write = read_id in read_ids
if write:
print >> type_read_file, line[:-1]
read_file.close()
type_read_file.close()
if paired:
write_reads("sim_1.fa", type_read1_fname)
write_reads("sim_2.fa", type_read2_fname)
else:
write_reads("sim_1.fa", type_read1_fname)
def init():
read_dir_base = "../reads/simulation/"
read_dirs = os.listdir(read_dir_base)
for read_dir in read_dirs:
if os.path.exists(read_dir):
continue
if not os.path.exists(read_dir_base + read_dir + "/sim.sam") or \
not os.path.exists(read_dir_base + read_dir + "/sim_1.fa") or \
not os.path.exists(read_dir_base + read_dir + "/sim_2.fa"):
continue
print >> sys.stderr, "Processing", read_dir, "..."
os.mkdir(read_dir)
os.chdir(read_dir)
os.system("ln -s ../%s%s/* ." % (read_dir_base, read_dir))
os.system("ln -s sim_1.fa 1.fa")
os.system("ln -s sim_2.fa 2.fa")
RNA = (read_dir.find("RNA") != -1)
init_reads(read_dir, RNA)
classify_reads(RNA)
os.system("ln -s ../calculate_read_cost.py .")
os.chdir("..")
if __name__ == "__main__":
init()