Add files via upload
This commit is contained in:
173
detect_yeast.py
Normal file
173
detect_yeast.py
Normal file
@ -0,0 +1,173 @@
|
||||
from keras.models import load_model
|
||||
from keras.utils import np_utils
|
||||
from util import *
|
||||
from segment_seed import *
|
||||
|
||||
|
||||
def cut_window(imageBF, center):
|
||||
r1 = int(center[0] - WINDOW_SHAPE[0] / 2)
|
||||
r2 = int(center[0] + WINDOW_SHAPE[0] / 2)
|
||||
c1 = int(center[1] - WINDOW_SHAPE[1] / 2)
|
||||
c2 = int(center[1] + WINDOW_SHAPE[1] / 2)
|
||||
return imageBF[r1:r2, c1:c2]
|
||||
|
||||
def windows_generator(imageBF, step_size):
|
||||
r_range = (int(WINDOW_SHAPE[0] / 2) + 1, IMAGE_SHAPE[0] - int(WINDOW_SHAPE[0] / 2) - 1)
|
||||
c_range = (int(WINDOW_SHAPE[1] / 2) + 1, IMAGE_SHAPE[1] - int(WINDOW_SHAPE[1] / 2) - 1)
|
||||
for r in range(r_range[0], r_range[1], step_size[0]):
|
||||
for c in range(c_range[0], c_range[1], step_size[1]):
|
||||
win = cut_window(imageBF, (r, c))
|
||||
center = (r, c)
|
||||
yield(win, center)
|
||||
|
||||
|
||||
def test_win_std(win):
|
||||
return win.std()/win.mean() < 0.1
|
||||
|
||||
|
||||
def test_stripes_std(win):
|
||||
r1 = int(WINDOW_SHAPE[0]/3)
|
||||
r2 = 2 * int(WINDOW_SHAPE[0]/3)
|
||||
c1 = int(WINDOW_SHAPE[1]/3)
|
||||
c2 = 2 * int(WINDOW_SHAPE[1]/3)
|
||||
if(win[r1:r2, :].std()/win[r1:r2, :].mean() < 0.1 or win[:, c1:c2].std()/win[:, c1:c2].mean() < 0.1):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def judge_yeast(win, model_detect):
|
||||
# filter out wrong windows using stdDev/mean within the window, if stdDev/mean<0.1, discard
|
||||
if(test_win_std(win)):
|
||||
return False
|
||||
# same as above, another way to filter out wrong windows
|
||||
elif(test_stripes_std(win)):
|
||||
return False
|
||||
else:
|
||||
im = win.reshape(1, WINDOW_SHAPE[0], WINDOW_SHAPE[1], 1)
|
||||
result = model_detect.predict(im)
|
||||
if(result[0, 0]==0.0 and result[0, 1]==1.0):
|
||||
return True
|
||||
elif(result[0, 0]==1.0 and result[0, 1]==0.0):
|
||||
return False
|
||||
|
||||
|
||||
def get_neighbor_list(center_list, center, neighbor_list):
|
||||
|
||||
pos_up = (center[0]-STEP_SIZE[0], center[1])
|
||||
pos_down = (center[0]+STEP_SIZE[0], center[1])
|
||||
pos_left = (center[0], center[1]-STEP_SIZE[1])
|
||||
pos_right = (center[0], center[1]+STEP_SIZE[1])
|
||||
poss = [pos_up, pos_down, pos_left, pos_right]
|
||||
# center_list.remove(center)
|
||||
neighbor_list.append(center)
|
||||
for pos in poss:
|
||||
if(pos in center_list and not pos in neighbor_list):
|
||||
get_neighbor_list(center_list, pos, neighbor_list)
|
||||
|
||||
|
||||
def get_neighbors(center_list, center):
|
||||
neighbors = []
|
||||
get_neighbor_list(center_list, center, neighbors)
|
||||
return list(set(neighbors))
|
||||
|
||||
|
||||
def merge_multi_detections(center_list):
|
||||
for center in center_list:
|
||||
center_list1 = center_list[:]
|
||||
neighbors = get_neighbors(center_list1, center)
|
||||
if(len(neighbors) > 1):
|
||||
for n in neighbors:
|
||||
center_list.remove(n)
|
||||
center_list.append(tuple(np.mean(np.array(neighbors), axis=0).astype(np.int32)))
|
||||
return center_list
|
||||
|
||||
|
||||
def detect_centers(imageBF, model_detect):
|
||||
center_list = list()
|
||||
for (win, center) in windows_generator(imageBF, STEP_SIZE):
|
||||
if(judge_yeast(win, model_detect) == True):
|
||||
center_list.append(center)
|
||||
# center_list = merge_multi_detections(center_list)
|
||||
return center_list
|
||||
|
||||
|
||||
def compute_vertex(win, model_rect):
|
||||
im = win.reshape(1, WINDOW_SHAPE[0], WINDOW_SHAPE[1], 1)/65535.
|
||||
vertex = model_rect.predict(im).astype(np.int32)
|
||||
return (vertex[0, 0], vertex[0, 1], vertex[0, 2], vertex[0, 3])
|
||||
|
||||
|
||||
def get_center_list(imageBF, model_detect):
|
||||
raw_center_list = detect_centers(imageBF, model_detect)
|
||||
center_list = raw_center_list #postprocessing of centers e.g. merge
|
||||
count = len(raw_center_list)
|
||||
new_count = 0
|
||||
|
||||
while(new_count != count):
|
||||
count = new_count
|
||||
center_list = merge_multi_detections(center_list)
|
||||
new_count = len(center_list)
|
||||
|
||||
return center_list
|
||||
|
||||
|
||||
def get_vertex_list(imageBF, model_detect, model_rect):
|
||||
center_list = get_center_list(imageBF, model_detect)
|
||||
vertex_list = list()
|
||||
for center in center_list:
|
||||
win = cut_window(imageBF, center)
|
||||
vertex = compute_vertex(win, model_rect)
|
||||
true_vertex = (vertex[0]+center[0]-WINDOW_SHAPE[0]//2,
|
||||
vertex[1]+center[0]-WINDOW_SHAPE[0]//2,
|
||||
vertex[2]+center[1]-WINDOW_SHAPE[1]//2,
|
||||
vertex[3]+center[1]-WINDOW_SHAPE[1]//2,)
|
||||
vertex_list.append(true_vertex)
|
||||
return vertex_list
|
||||
|
||||
|
||||
def get_polygon_list(image, center_list):
|
||||
(slopes, gradx, grady, slopes2, grad2x, grad2y, gradxy) = findslopes(image)
|
||||
polygon_list = list()
|
||||
for i in range(len(center_list)):
|
||||
print("processing %s" %i)
|
||||
polygon = get_polygon(image, gradx, grady, center_list[i])
|
||||
polygon_list.append(polygon)
|
||||
return polygon_list
|
||||
|
||||
def plot_detection_center(imageBF, center_list):
|
||||
colormap = mpl.cm.gray
|
||||
plt.imshow(imageBF, cmap=colormap)
|
||||
for center in center_list:
|
||||
plt.plot(center[1], center[0], 'r*')
|
||||
plt.xlim(0, 512)
|
||||
plt.ylim(512, 0)
|
||||
plt.show()
|
||||
|
||||
def plot_detection_rect(imageBF, vertex_list):
|
||||
colormap = mpl.cm.gray
|
||||
plt.imshow(imageBF, cmap=colormap)
|
||||
for (r1,r2,c1,c2) in vertex_list:
|
||||
plt.plot(np.ones(r2-r1)*c1, np.array(range(r1, r2)), 'r')
|
||||
plt.plot(np.ones(r2-r1)*c2, np.array(range(r1, r2)), 'r')
|
||||
plt.plot(np.array(range(c1, c2)), np.ones(c2-c1)*r1, 'r')
|
||||
plt.plot(np.array(range(c1, c2)), np.ones(c2-c1)*r2, 'r')
|
||||
plt.xlim(0, IMAGE_SHAPE[1])
|
||||
plt.ylim(IMAGE_SHAPE[0], 0)
|
||||
plt.show()
|
||||
|
||||
def plot_polygons(img, polygon_list):
|
||||
plt.imshow(img, cmap=mpl.cm.gray)
|
||||
for i in range(len(polygon_list)):
|
||||
plt.plot(polygon_list[i][:,1], polygon_list[i][:, 0], 'r')
|
||||
plt.xlim(0, IMAGE_SHAPE[1])
|
||||
plt.ylim(IMAGE_SHAPE[0], 0)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
image = np.array(Image.open('./examples/example1.tif'))
|
||||
model_detect = load_model('./models/CNN_detect6.h5')
|
||||
center_list = get_center_list(image, model_detect)
|
||||
polygon_list = get_polygon_list(image, center_list)
|
||||
plot_polygons(image, polygon_list)
|
Reference in New Issue
Block a user