174 lines
5.9 KiB
Python
174 lines
5.9 KiB
Python
from keras.models import load_model
|
|
from keras.utils import np_utils
|
|
from util import *
|
|
from segment_seed import *
|
|
|
|
|
|
def cut_window(imageBF, center):
    """Crop a WINDOW_SHAPE-sized patch of ``imageBF`` around ``center``.

    ``center`` is a ``(row, col)`` pair. Bounds are computed as
    ``int(center - size / 2)`` and ``int(center + size / 2)``. For a positive
    integer centre this yields exactly ``size`` rows/columns, even when the
    size is odd.

    No bounds checking is done. Near the image border the returned patch can
    be smaller than WINDOW_SHAPE, or pick up wrapped pixels through negative
    slice starts.
    """
    half_rows = WINDOW_SHAPE[0] / 2
    half_cols = WINDOW_SHAPE[1] / 2
    top, bottom = int(center[0] - half_rows), int(center[0] + half_rows)
    left, right = int(center[1] - half_cols), int(center[1] + half_cols)
    return imageBF[top:bottom, left:right]
|
|
|
|
def windows_generator(imageBF, step_size):
    """Lazily yield ``(window, (row, col))`` for a sliding-window scan.

    Centres lie on a ``step_size`` lattice, row-major. Each centre is kept at
    least half a window plus one pixel away from the image border. The scan
    area comes from IMAGE_SHAPE, not from ``imageBF.shape``.
    """
    half_r = int(WINDOW_SHAPE[0] / 2)
    half_c = int(WINDOW_SHAPE[1] / 2)
    rows = range(half_r + 1, IMAGE_SHAPE[0] - half_r - 1, step_size[0])
    cols = range(half_c + 1, IMAGE_SHAPE[1] - half_c - 1, step_size[1])
    for row in rows:
        for col in cols:
            position = (row, col)
            yield (cut_window(imageBF, position), position)
|
|
|
|
|
|
def test_win_std(win, threshold=0.1):
    """Return True when ``win`` is too uniform to contain an object.

    Uniformity is measured by the coefficient of variation
    ``win.std() / win.mean()``. The window counts as uniform when that ratio
    is below ``threshold``.

    Args:
        win: numeric array (the candidate window).
        threshold: coefficient-of-variation cut-off; 0.1 keeps the
            historical behaviour.

    Returns:
        A plain ``bool``.
    """
    mean = win.mean()
    std = win.std()
    if mean == 0:
        # std/mean is undefined here (0/0 -> nan, x/0 -> inf) and would never
        # compare below the threshold. A perfectly flat window is the most
        # uniform case, so report it as such. A zero-mean window that still
        # varies is kept (False), as before.
        return bool(std == 0)
    return bool(std / mean < threshold)


# The name matches pytest's ``test_*`` discovery pattern. This module is used
# via star imports, so opt out of collection; otherwise pytest would try to
# run it and fail on the missing ``win`` fixture.
test_win_std.__test__ = False
|
|
|
|
|
|
def test_stripes_std(win, threshold=0.1):
    """Return True when the central row band or central column band is flat.

    The bands are the middle thirds of the window: rows
    ``[int(H/3), 2*int(H/3))`` and columns ``[int(W/3), 2*int(W/3))``, where
    ``(H, W) = WINDOW_SHAPE``. A band is flat when its coefficient of
    variation (std / mean) is below ``threshold``. The row band is tested
    first and the column band only if needed.

    Args:
        win: numeric array shaped like WINDOW_SHAPE.
        threshold: coefficient-of-variation cut-off; 0.1 keeps the
            historical behaviour.

    Returns:
        A plain ``bool``.
    """
    third_r = int(WINDOW_SHAPE[0] / 3)
    third_c = int(WINDOW_SHAPE[1] / 3)

    def is_flat(band):
        mean = band.mean()
        std = band.std()
        if mean == 0:
            # std/mean is undefined (nan/inf) for a zero-mean band; treat a
            # perfectly constant band as flat instead of silently keeping it.
            return bool(std == 0)
        return bool(std / mean < threshold)

    return (is_flat(win[third_r:2 * third_r, :])
            or is_flat(win[:, third_c:2 * third_c]))


# The name matches pytest's ``test_*`` discovery pattern. This module is used
# via star imports, so opt out of collection; otherwise pytest would try to
# run it and fail on the missing ``win`` fixture.
test_stripes_std.__test__ = False
|
|
|
|
|
|
def judge_yeast(win, model_detect):
    """Classify one candidate window with the detection model.

    Windows rejected by the cheap contrast pre-filters (``test_win_std``,
    ``test_stripes_std``) return False without calling the model. Otherwise
    the window is reshaped to ``(1, H, W, 1)`` and passed to
    ``model_detect.predict``. The window is positive when the score for class
    index 1 exceeds the score for class index 0.

    The previous version compared the float outputs with exact ``== 0.0`` and
    ``== 1.0``. Any non-saturated prediction therefore fell through and
    returned ``None``.

    NOTE(review): unlike ``compute_vertex``, the window is not scaled by
    1/65535 before prediction. Presumably this matches how the detector was
    trained; confirm.

    Returns:
        A plain ``bool``.
    """
    # Discard low-contrast windows using std/mean over the whole window or
    # over its central row/column bands.
    if test_win_std(win) or test_stripes_std(win):
        return False
    batch = win.reshape(1, WINDOW_SHAPE[0], WINDOW_SHAPE[1], 1)
    scores = model_detect.predict(batch)
    return bool(scores[0, 1] > scores[0, 0])
|
|
|
|
|
|
def get_neighbor_list(center_list, center, neighbor_list):
    """Append to ``neighbor_list`` every centre connected to ``center``.

    Two centres are connected when they differ by exactly one STEP_SIZE step
    up, down, left or right, transitively through ``center_list``.
    ``center`` itself is always appended first. Later centres are appended
    in the same depth-first order the former recursive version produced
    (up, down, left, right). Entries already in ``neighbor_list`` are
    treated as visited.

    The traversal uses an explicit stack, so large clusters can't hit
    Python's recursion limit. Membership tests use sets (O(1)) instead of
    list scans. Centres are assumed to be hashable tuples, as produced
    elsewhere in this module.

    Returns None; the result is accumulated in ``neighbor_list``.
    """
    members = set(center_list)
    seen = set(neighbor_list)

    def adjacent(pos):
        return ((pos[0] - STEP_SIZE[0], pos[1]),   # up
                (pos[0] + STEP_SIZE[0], pos[1]),   # down
                (pos[0], pos[1] - STEP_SIZE[1]),   # left
                (pos[0], pos[1] + STEP_SIZE[1]))   # right

    neighbor_list.append(center)
    seen.add(center)
    # Each stack entry is the partially consumed neighbour iterator of one
    # node. Resuming it reproduces the visiting order of the recursive DFS.
    stack = [iter(adjacent(center))]
    while stack:
        for pos in stack[-1]:
            if pos in members and pos not in seen:
                neighbor_list.append(pos)
                seen.add(pos)
                stack.append(iter(adjacent(pos)))
                break
        else:
            stack.pop()
|
|
|
|
|
|
def get_neighbors(center_list, center):
    """Return the distinct centres connected to ``center`` (itself included).

    Collects the cluster with ``get_neighbor_list`` and de-duplicates it.
    The order of the returned list is the set's iteration order, not the
    traversal order.
    """
    cluster = []
    get_neighbor_list(center_list, center, cluster)
    unique = {*cluster}
    return list(unique)
|
|
|
|
|
|
def merge_multi_detections(center_list):
    """Collapse clusters of adjacent detections into one centre each.

    Clusters come from ``get_neighbors``. Every cluster with more than one
    member is replaced by the truncated integer mean of its members, built
    as a tuple of ``np.int32``. Isolated detections are kept as-is.

    The result lists the kept detections in their original order, then the
    merged centroids in discovery order. ``center_list`` is rewritten in
    place and also returned, so callers relying on either contract still
    work.
    """
    # Work on a snapshot. The old version removed and appended while a
    # for-loop was walking the same list, which skips elements and re-visits
    # freshly appended centroids in the same pass.
    snapshot = list(center_list)
    consumed = set()
    kept = []
    merged = []
    for center in snapshot:
        if center in consumed:
            continue
        cluster = get_neighbors(snapshot, center)
        if len(cluster) > 1:
            consumed.update(cluster)
            merged.append(tuple(np.mean(np.array(cluster), axis=0).astype(np.int32)))
        else:
            kept.append(center)
    center_list[:] = kept + merged
    return center_list
|
|
|
|
|
|
def detect_centers(imageBF, model_detect):
    """Return the raw (unmerged) centres of every window the detector accepts.

    Scans ``imageBF`` with ``windows_generator`` at STEP_SIZE and keeps each
    window centre for which ``judge_yeast`` returns True. Merging of adjacent
    hits is left to ``get_center_list``.
    """
    # Explicit ``== True`` on purpose, so only a true result is accepted.
    return [position
            for window, position in windows_generator(imageBF, STEP_SIZE)
            if judge_yeast(window, model_detect) == True]  # noqa: E712
|
|
|
|
|
|
def compute_vertex(win, model_rect):
    """Predict the bounding box of the object in ``win``, in window coordinates.

    The window is reshaped to ``(1, H, W, 1)`` and divided by 65535
    (presumably scaling 16-bit intensities to [0, 1]; confirm the input
    dtype). It is then passed to ``model_rect.predict``. The first four
    outputs are truncated to ``np.int32`` and returned as a 4-tuple.
    Callers treat it as ``(r1, r2, c1, c2)``.
    """
    batch = win.reshape(1, WINDOW_SHAPE[0], WINDOW_SHAPE[1], 1) / 65535.
    prediction = model_rect.predict(batch).astype(np.int32)
    first_row = prediction[0]
    return tuple(first_row[k] for k in range(4))
|
|
|
|
|
|
def get_center_list(imageBF, model_detect):
    """Detect centres in ``imageBF`` and merge adjacent hits to a fixed point.

    ``merge_multi_detections`` is applied repeatedly until one pass leaves
    the number of centres unchanged. An empty detection result is returned
    without any merging.
    """
    centers = detect_centers(imageBF, model_detect)
    if centers:
        size = 0
        while True:
            centers = merge_multi_detections(centers)
            new_size = len(centers)
            if new_size == size:
                break
            size = new_size
    return centers
|
|
|
|
|
|
def get_vertex_list(imageBF, model_detect, model_rect):
    """Detect objects and return their bounding boxes in image coordinates.

    For each merged centre from ``get_center_list``:
      1. crop the window with ``cut_window``;
      2. predict ``(r1, r2, c1, c2)`` relative to the window with
         ``compute_vertex``;
      3. shift the box by the window's top-left corner.

    Returns:
        A list of ``(r1, r2, c1, c2)`` tuples.
    """
    vertex_list = []
    for center in get_center_list(imageBF, model_detect):
        win = cut_window(imageBF, center)
        r1, r2, c1, c2 = compute_vertex(win, model_rect)
        # Use the same origin formula as cut_window, int(center - size / 2).
        # The former ``center - size // 2`` disagrees by one pixel when the
        # window size is odd.
        row0 = int(center[0] - WINDOW_SHAPE[0] / 2)
        col0 = int(center[1] - WINDOW_SHAPE[1] / 2)
        vertex_list.append((r1 + row0, r2 + row0, c1 + col0, c2 + col0))
    return vertex_list
|
|
|
|
|
|
def get_polygon_list(image, center_list):
    """Return one polygon per centre, computed by ``get_polygon``.

    ``findslopes`` is called once on the whole image. Only its x and y
    gradients (2nd and 3rd outputs) are used; they are passed to
    ``get_polygon`` for every centre. A progress line is printed per centre.
    ``findslopes`` and ``get_polygon`` come from the star imports
    (presumably ``segment_seed``).
    """
    (_slopes, gradx, grady,
     _slopes2, _grad2x, _grad2y, _gradxy) = findslopes(image)
    polygons = []
    for idx, center in enumerate(center_list):
        print("processing %s" % idx)
        polygons.append(get_polygon(image, gradx, grady, center))
    return polygons
|
|
|
|
def plot_detection_center(imageBF, center_list):
    """Show ``imageBF`` in grey with a red star at every ``(row, col)`` centre.

    The axes span IMAGE_SHAPE with the y axis inverted (row 0 at the top).
    This matches the other plotting helpers in this module, which previously
    differed by hard-coding 512x512.
    """
    plt.imshow(imageBF, cmap=mpl.cm.gray)
    for row, col in center_list:
        # matplotlib takes (x, y) = (col, row).
        plt.plot(col, row, 'r*')
    plt.xlim(0, IMAGE_SHAPE[1])
    plt.ylim(IMAGE_SHAPE[0], 0)
    plt.show()
|
|
|
|
def plot_detection_rect(imageBF, vertex_list):
    """Show ``imageBF`` in grey with a red outline for every box.

    Each box is ``(r1, r2, c1, c2)`` in image coordinates. The axes span
    IMAGE_SHAPE with the y axis inverted (row 0 at the top).
    """
    plt.imshow(imageBF, cmap=mpl.cm.gray)
    for (r1, r2, c1, c2) in vertex_list:
        # Draw the outline as one closed polyline through the four corners.
        # The former per-edge ``np.ones(r2 - r1)`` construction raised
        # ValueError for inverted boxes (r2 < r1 or c2 < c1) and stopped one
        # pixel short of the far corner.
        plt.plot([c1, c1, c2, c2, c1], [r1, r2, r2, r1, r1], 'r')
    plt.xlim(0, IMAGE_SHAPE[1])
    plt.ylim(IMAGE_SHAPE[0], 0)
    plt.show()
|
|
|
|
def plot_polygons(img, polygon_list):
    """Show ``img`` in grey with every polygon outlined in red.

    Each polygon is an array whose column 0 holds rows and column 1 holds
    columns; it is plotted as (x, y) = (col, row). The axes span IMAGE_SHAPE
    with the y axis inverted.
    """
    plt.imshow(img, cmap=mpl.cm.gray)
    for outline in polygon_list:
        xs = outline[:, 1]
        ys = outline[:, 0]
        plt.plot(xs, ys, 'r')
    plt.xlim(0, IMAGE_SHAPE[1])
    plt.ylim(IMAGE_SHAPE[0], 0)
    plt.show()
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo pipeline on the bundled example:
    #   1. detect and merge centres with the CNN detector;
    #   2. compute one polygon per centre via get_polygon_list;
    #   3. show the polygons over the image.
    # Open the TIFF in a context manager so its file handle is released once
    # the pixels have been copied into the NumPy array.
    with Image.open('./examples/example1.tif') as tif:
        image = np.array(tif)
    model_detect = load_model('./models/CNN_detect6.h5')
    center_list = get_center_list(image, model_detect)
    polygon_list = get_polygon_list(image, center_list)
    plot_polygons(image, polygon_list)
|