# -*- coding: utf-8 -*-
"""
Created on Wed Jan 18 13:28:00 2016
@author: Xiaohui Zhu
"""
from numpy import fft
import numpy as np
import pylab as plt
from PIL import Image

# Read the ptychography image.
# The location and the base name are hard-coded; edit them for other data.
image_path = 'c:\\Downloads\\FRC_ALS\\'  # Image location (ends with a path separator)
image_name = 'fuel-cell-8bit'  # Image name without the '.tif' extension
im = Image.open(image_path + image_name + '.tif')
# A parenthesised single-argument print gives the same output on Python 2 and 3.
print("Size of the image is: " + str(im.size))

# Split the original image into two half-size subsets.
#
# Each subset keeps the source pixels whose first coordinate has one parity
# (even or odd) and copies source pixel (i, j) to target pixel (i // 2, j // 2).
# Every value of the second coordinate is visited, so j = 2k and j = 2k + 1 are
# written to the same target position and the later write (odd j) is the one
# that remains.
# NOTE(review): when a dimension of the image is odd, a target index can fall
# outside the half-size image; the IndexError is reported below, the subset is
# left partially filled and it is not saved -- confirm this is acceptable.
half_size = tuple(d // 2 for d in im.size)
src_width, src_height = im.size

# get the even subset
half_even = Image.new('L', half_size)
try:
    for i in range(0, src_width, 2):
        for j in range(src_height):
            half_even.putpixel((i // 2, j // 2), im.getpixel((i, j)))
    half_even.save(image_path + 'half_even-' + image_name + '.tif')
except IndexError:
    print('Note: they are not both even numbers')

# get the odd subset
half_odd = Image.new('L', half_size)
try:
    for i in range(1, src_width, 2):
        for j in range(src_height):
            half_odd.putpixel((i // 2, j // 2), im.getpixel((i, j)))
    half_odd.save(image_path + 'half_odd-' + image_name + '.tif')
except IndexError:
    print('Note: they are not both even numbers')
        
## Load the odd and even subsets as arrays
# np.roll returns numpy arrays: a shift of 0 leaves the odd subset unchanged,
# while the even subset is rotated by one position along axis 1.
image1 = np.roll(half_odd, 0, axis=0)
image2 = np.roll(half_even, 1, axis=1)

##  Crop both images to the same square and get the Nyquist frequency
ny, nx = image1.shape
maxSize = min(nx, ny)
image1 = image1[:maxSize, :maxSize]
image2 = image2[:maxSize, :maxSize]
freq_nyq = int(np.floor(maxSize / 2.0))

##  Create Fourier grid: distance of every grid point from the zero offset
lowest = -np.floor(maxSize / 2.0)
highest = np.ceil(maxSize / 2.0)
x, y = np.meshgrid(np.arange(lowest, highest), np.arange(lowest, highest))
map_dist = np.sqrt(x * x + y * y)

##  FFT transforms of the input images (zero frequency shifted to the centre)
fft_image1 = fft.fftshift(fft.fftn(image1))
fft_image2 = fft.fftshift(fft.fftn(image2))
    
## Smooth the curve
def savitzky_golay(y, window_size, order, deriv=0, rate=1):
    """Smooth (or differentiate) a 1-D signal with a Savitzky-Golay filter.

    A polynomial of degree ``order`` is least-squares fitted to each window of
    ``window_size`` consecutive samples and evaluated (or differentiated
    ``deriv`` times) at the window centre.

    Parameters
    ----------
    y : array_like
        One-dimensional signal.
    window_size : int
        Window length; must be a positive odd number.
    order : int
        Degree of the fitted polynomial; must not exceed ``window_size - 2``.
    deriv : int, optional
        Order of the derivative to return (0 means plain smoothing).
    rate : number, optional
        Scale factor; the result is multiplied by ``rate ** deriv``.

    Returns
    -------
    numpy.ndarray
        The filtered signal.  It has the same length as ``y`` whenever ``y``
        holds at least ``(window_size - 1) // 2 + 1`` samples.

    Raises
    ------
    ValueError
        If ``window_size`` or ``order`` cannot be converted to ``int``.
    TypeError
        If ``window_size`` is not a positive odd number, or is too small for
        the requested polynomial order.
    """
    from math import factorial

    # Accept plain sequences as well as arrays; the padding below relies on
    # element-wise arithmetic.
    y = np.asarray(y)
    try:
        # Builtin int is used because the np.int alias no longer exists in
        # current NumPy releases.
        window_size = abs(int(window_size))
        order = abs(int(order))
    except ValueError:
        raise ValueError("window_size and order have to be of type int")
    if window_size % 2 != 1 or window_size < 1:
        raise TypeError("window_size size must be a positive odd number")
    if window_size < order + 2:
        raise TypeError("window_size is too small for the polynomials order")
    half_window = (window_size - 1) // 2
    # Precompute the filter coefficients: row ``deriv`` of the pseudo-inverse
    # of the design matrix [k**0, k**1, ..., k**order], k = -half_window..half_window.
    design = np.array(
        [[k ** i for i in range(order + 1)]
         for k in range(-half_window, half_window + 1)],
        dtype=float)
    m = np.linalg.pinv(design)[deriv] * rate ** deriv * factorial(deriv)
    # Pad the signal at both extremes with values taken from the signal
    # itself (mirrored about the first / last sample).
    firstvals = y[0] - np.abs(y[1:half_window + 1][::-1] - y[0])
    lastvals = y[-1] + np.abs(y[-half_window - 1:-1][::-1] - y[-1])
    padded = np.concatenate((firstvals, y, lastvals))
    # Convolving with the reversed kernel applies the coefficients in window order.
    return np.convolve(m[::-1], padded, mode='valid')

##  Thickness and starting radius of the rings
# Both values are expressed in the units of map_dist (offsets on the Fourier grid).
width_ring = 0.8  # radial thickness of one ring; also the amount r advances per ring
r = 3.0  # inner radius of the first ring; increased while the FRC rings are collected

def find_ring_area( map_dist, r, width_ring):
    """Return the indices of the grid points that lie inside one ring.

    Parameters
    ----------
    map_dist : numpy.ndarray
        Distance of every grid point from the origin.
    r : float
        Inner radius of the ring (inclusive).
    width_ring : float
        Ring thickness; the outer radius ``r + width_ring`` is inclusive too.

    Returns
    -------
    numpy.ndarray
        One row of indices per selected grid point, as produced by
        ``np.argwhere``.
    """
    inside_ring = (map_dist >= r) & (map_dist <= r + width_ring)
    return np.argwhere(inside_ring)
    
##  Calculate FRC
# For every ring of the Fourier grid accumulate
#   C1 = sum(F1 * conj(F2)),  C2 = sum(|F1|**2),  C3 = sum(|F2|**2)
# together with n, the number of grid points inside the ring.
C1 = []
C2 = []
C3 = []
n = []

# Walk outwards ring by ring until the outer edge reaches the Nyquist frequency.
while r + width_ring < freq_nyq:
    ring = find_ring_area(map_dist, r, width_ring)
    aux1 = fft_image1[ring[:, 0], ring[:, 1]]
    aux2 = fft_image2[ring[:, 0], ring[:, 1]]
    C1.append(np.sum(aux1 * np.conjugate(aux2)))
    C2.append(np.sum(np.abs(aux1) ** 2))
    C3.append(np.sum(np.abs(aux2) ** 2))
    n.append(len(aux1))
    r += width_ring

if not n:
    # Without this check the failure would surface later as an unrelated error.
    raise ValueError("image is too small: no ring fits below the Nyquist frequency")

# FRC per ring = |C1| / sqrt(C2 * C3).  The curve and its smoothing are
# computed once, after all rings are known, instead of on every iteration.
FRC = np.abs(np.array(C1)) / np.sqrt(np.array(C2) * np.array(C3))
FRC = savitzky_golay(FRC, 7, 1)

# Half-bit threshold curve, evaluated from the number of points in each ring.
sqrt_n = np.sqrt(np.array(n, dtype=float))
half_bit = (0.2071 + 1.9102 / sqrt_n) / (1.2071 + 0.9102 / sqrt_n)

# Pixel size read from ptycho image (presumably nanometres per pixel, judging
# by the variable name -- confirm):
pixnm = 8.3  # Change it for different images
nbins = len(FRC)
# Frequency axis: nbins evenly spaced values from 0 to 1.414 / pixnm.
freqs = np.linspace(0, 1.414 / pixnm, nbins)
half_height = np.ones((nbins)) * 0.5  # constant 0.5 reference line

# Plot the FRC curve together with the 0.5 line and the half-bit curve.
plt.plot(freqs, FRC, marker='o', mfc='blue', mec='white', lw=1, linestyle='-', color='black')
plt.plot(freqs, half_height, linestyle='--', color='grey')
plt.plot(freqs, half_bit, 'r-')
axes = plt.gca()
axes.set_xlim([0, 0.18])
axes.set_ylim([0, 1.05])
plt.show()

# Save each curve to its own text file in the working directory, one value
# per line.  All nbins values are written so the files match the plotted
# curves; ``with`` closes each file even if a write fails.
for suffix, values in (('-frc.txt', FRC),
                       ('-freqs.txt', freqs),
                       ('-half_height.txt', half_height),
                       ('-half_bit.txt', half_bit)):
    with open(image_name + suffix, 'w') as out:
        out.writelines(str(value) + '\n' for value in values)

im.show()  # Show which image processed