package hirondelle.web4j.security;

import hirondelle.web4j.Controller;
import hirondelle.web4j.database.DAOException;
import hirondelle.web4j.database.SqlId;
import hirondelle.web4j.model.Id;
import hirondelle.web4j.util.Util;
import hirondelle.web4j.util.WebUtil;

import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Random;
import java.util.logging.Logger;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;

/**
 Protect your application from a 
 <a href='http://en.wikipedia.org/wiki/Cross-site_request_forgery'>Cross Site Request Forgery</a> (CSRF).

 <P>Please see the package overview for important information regarding CSRF attacks, and security in general.
 
 <P>This filter maintains various items needed to protect against CSRF attacks. It acts both as a 
 pre-processor and as a post-processor. The behavior of this class is controlled by detecting two important events: 
 <ul>
  <li>the creation of new sessions (which does <i>not</i> necessarily imply a successful user login has also occurred)
  <li>a successful user login (which <i>does</i> imply a session has also been created)
 </ul> 

 <h4>Pre-processing</h4>
 When <i>a new session</i> is detected (but not necessarily a user login), then this class will do the following :
 <ul>
 <li>calculate a random form-source id, and place it in session scope, under the key {@link #FORM_SOURCE_ID_KEY}. 
 This value is difficult to guess.
 <li>wrap the response in a custom wrapper, to implement the post-processing performed by this filter (see below)
 </ul>
 
 In addition, if <i>a new user login</i> is detected, then this class will do the following :
 <ul>
 <li>if there is any 'old' form-source id, place it in session scope as well, under the 
 key {@link #PREVIOUS_FORM_SOURCE_ID_KEY}. The 'old' form-source id is simply the form-source id 
 used in the <em>immediately preceding session for the same user</em>.
 <li>place in session scope an object which will store the form-source id when the session expires or is invalidated, under 
 the key {@link #FORM_SOURCE_DAO_KEY}.
 </ul> 
 
 <P>The above behavior of this class upon user login requires interaction with your database. 
 It's configured in <tt>web.xml</tt> using two items : 
 <tt>FormSourceIdRead</tt> and <tt>FormSourceIdWrite</tt>. These two items are 
 {@link hirondelle.web4j.database.SqlId} references. 
 They tell this class which SQL statements to use when reading and writing form-source ids 
 to the database. As usual, these {@link SqlId} items must be declared somewhere in your 
 application as <tt>public static final</tt> fields, and the corresponding SQL statements 
 must appear somewhere in an <tt>.sql</tt> file.
 
 <P>(Please see these items in the example application for an illustration : <tt>web.xml</tt>, 
 <tt>UserDAO</tt>, and <tt>csrf.sql</tt>.) 
 
 <h4>Post-processing</h4>
 If a session is present, then this class will use a custom response wrapper to alter the response:
 <ul>
 <li>if the response has <tt>content-type</tt> of <tt>text/html</tt> (or <tt>null</tt>), then scan 
 the response for all {@code <FORM>} tags with <tt>method='POST'</tt>. 
 <li>for each such {@code <FORM>} tag, add a hidden parameter in the following style :
<PRE>&lt;input type='hidden' name='web4j_key_for_form_source_id' value='151jdk65654dasdf545sadf6a5s4f'&gt;</PRE>
</ul>
 
 The name of the hidden parameter is taken from {@link #FORM_SOURCE_ID_KEY}, 
 and the <tt>value</tt> of that hidden parameter is the random token created during the pre-processing stage.

<h4>ApplicationFirewall</h4>
This class cooperates closely with {@link hirondelle.web4j.security.ApplicationFirewallImpl}. It is the 
firewall which performs the actual test to make sure the POSTed form came from your web app. 

 <h4>Warning Regarding Error Pages</h4>
 This Filter uses a wrapper for the response. When a Filter wraps the response, the error page 
 customization defined by <tt>web.xml</tt> will likely not function. 
 (This may be a defect of the Servlet API itself - see section 9.9.3.) That is, when an error occurs when using this 
 Filter, the generic error pages defined by the container may be served, instead of the custom 
 error pages you have configured in <tt>web.xml</tt>.
 
 <P>This filter will only affect the response if its content-type is <tt>text/html</tt> or <tt>null</tt>.
 It will not affect any other type of response.
*/
public class CsrfFilter implements Filter {

  /** 
   <em>Key</em> for item stored in session scope, and also <em>name</em> of hidden 
   request parameter added to POSTed forms.
   
   <P>Value - {@value}.
   <P>The <em>value</em> of this item is generated randomly for each new session that does not yet 
   have one, and contains a simple token that is hard to guess. Each POSTed form will be required by 
   {@link ApplicationFirewallImpl} to include a hidden parameter of this <em>name</em>, and the 
   <em>value</em> of such hidden parameters are matched to the corresponding item stored in session 
   scope under the same key. These checks verify that POSTed forms have come from a trusted source.
  */
  public static final String FORM_SOURCE_ID_KEY = "web4j_key_for_form_source_id";

  /** 
   Key for item stored in session scope.
   
   <P>Value - {@value}.
   <P>The value of this item is retrieved from the database for each new user login, and 
   represents the form-source id for the user's <em>immediately preceding</em> session. 
   When a match of form-source id against {@link #FORM_SOURCE_ID_KEY} fails, then a second 
   match is attempted against this item.
   
   <P>Please see the package description for an explanation of why this is necessary.
  */
  public static final String PREVIOUS_FORM_SOURCE_ID_KEY = "web4j_key_for_previous_form_source_id";
  
  /**
   Key for item stored in session scope.
     
   <P>Value - {@value}.
   <P>This item points to an {@link javax.servlet.http.HttpSessionBindingListener} object placed in each new session. 
   When the session ends, that object will be unbound from the session, and will save the user's current form-source id 
   to the database, for future use.  
  */
  public static final String FORM_SOURCE_DAO_KEY = "web4j_key_for_form_source_dao";

  /** 
   Read in filter configuration. 
   
   <P>Reads in {@link hirondelle.web4j.database.SqlId} references used to read and write the user's form-source id,
   from the init-params named <tt>FormSourceIdRead</tt> and <tt>FormSourceIdWrite</tt>.
   A missing or empty init-param is logged as SEVERE, but does not abort initialization.
   <P>See class comment and package-level description for further information.
  */
  public void init(FilterConfig aFilterConfig)  {
    fLogger.config("INIT : " + this.getClass().getName() + ". Reading in SqlIds for reading and writing form-source ids.");
    String read_sql = aFilterConfig.getInitParameter("FormSourceIdRead");
    String write_sql = aFilterConfig.getInitParameter("FormSourceIdWrite");
    checkValidSqlId(read_sql);
    checkValidSqlId(write_sql);
    CsrfDAO.init(read_sql, write_sql);
  }
  
  /** This implementation only logs a message.  */
  public void destroy() {
    fLogger.config("DESTROY : " + this.getClass().getName());
  }
  
  /**
   Protect against CSRF attacks.
   
   <P>Adds CSRF items to any new session, then, if the response is treated as html, buffers the 
   output of the rest of the chain, passes it through {@link CsrfModifiedResponse#addNonceTo}, and writes 
   the result to the real response, with a content-length computed from the final text.
  
   <P>See class comment and package-level description for further information.
  */
  public void doFilter(ServletRequest aRequest, ServletResponse aResponse, FilterChain aChain) throws IOException, ServletException {
    fLogger.fine("START CSRF Filter.");
    HttpServletRequest request = (HttpServletRequest)aRequest;
    HttpServletResponse response = (HttpServletResponse)aResponse;
    
    addItemsForNewSessions(request);
    
    //NOTE(review): the content-type is examined before the rest of the chain runs, when it is 
    //often not yet set; a missing content-type is treated as html
    if(isServingHtml(response)){
      fLogger.fine("Serving html. Wrapping response.");
      CharResponseWrapper wrapper = new CharResponseWrapper(response);
      aChain.doFilter(aRequest, wrapper); //AppFirewall and BadRequest
      
      CsrfModifiedResponse modifiedResponse = new CsrfModifiedResponse(request, response);
      String modifiedOutput = modifiedResponse.addNonceTo(wrapper.toString());
      //content-length counts bytes, not chars, so it depends on the encoding
      aResponse.setContentLength(modifiedOutput.getBytes(getEncoding(request, response)).length);
      
      PrintWriter writer = aResponse.getWriter(); //this will use the response's encoding
      writer.write(modifiedOutput);
      writer.close();
    }
    else {
      fLogger.fine("Not serving html. Not modifiying response.");
      aChain.doFilter(aRequest, aResponse); //do nothing special
    }
    fLogger.fine("END CSRF Filter.");
  }
  
  /**
   Add a CSRF token to an existing session <i>that has no user login</i>.
   
   <P><i>This method is called only when a session is created by an Action, instead of by the usual login mechanism.</i>
   See {@link hirondelle.web4j.action.ActionImpl#createSessionAndCsrfToken()} for important information.  
  */
  public void addCsrfToken(HttpServletRequest aRequest) throws ServletException {
    addItemsForNewSessions(aRequest);
  }
  
  // PRIVATE
  
  //WARNING : Filters always need to be thread-safe !!
  
  private static final Logger fLogger = Util.getLogger(CsrfFilter.class);
  private static final boolean DO_NOT_CREATE = false;
  
  /** 
   Source of randomness for form-source ids. 
   A CSPRNG is required here, since the token must not be guessable by an attacker. 
   SecureRandom is safe for use by multiple threads.
  */
  private static final SecureRandom fRandom = new SecureRandom();
  
  /** Number of random bytes in a form-source id: 20 bytes gives 40 hex chars, the length of a SHA-1 hex digest. */
  private static final int FORM_SOURCE_ID_NUM_BYTES = 20;
  
  /** Log a SEVERE message if the given SqlId init-param has no content. Does not throw. */
  private static void checkValidSqlId(String aSqlId){
    if ( ! Util.textHasContent(aSqlId) ) {
      String message = "SqlId required as Filter init-param, but has no content: " + Util.quote(aSqlId); 
      fLogger.severe(message);
    }
  }
  
  /**
   If a session exists but has no form-source id yet, create one and place it in the session. 
   If the user has also logged in, place the previous form-source id (if any) and a CsrfDAO 
   in the session as well. Never creates a session.
  */
  private void addItemsForNewSessions(HttpServletRequest aRequest) throws ServletException {
    HttpSession session = aRequest.getSession(DO_NOT_CREATE);
    if ( sessionExists(session) ){
      if ( hasNoFormSourceIdInSession(session) ){
        Id currentFormSourceId = calcFormSourceId();
        addFormSourceIdToSession(session, currentFormSourceId);
        if( userHasLoggedIn(aRequest) ){
          CsrfDAO formSourceDAO = new CsrfDAO(aRequest.getUserPrincipal().getName(), currentFormSourceId);
          addPreviousFormSourceIdToSession(session, formSourceDAO);          
          addFormSourceDAOToSession(session, formSourceDAO);
        }
      }
    }
  }
  
  private boolean sessionExists(HttpSession aSession){
    return aSession != null;
  }
  
  private boolean hasNoFormSourceIdInSession(HttpSession aSession){
    return aSession.getAttribute(FORM_SOURCE_ID_KEY) == null;
  }

  private boolean userHasLoggedIn(HttpServletRequest aRequest){
    return aRequest.getUserPrincipal() != null;
  }
  
  private void addFormSourceIdToSession(HttpSession aSession, Id aCurrentFormSourceId) {
    fLogger.fine("Adding new form-source id to user's session.");
    aSession.setAttribute(FORM_SOURCE_ID_KEY, aCurrentFormSourceId);
  }
  
  /** 
   Return a new, hard-to-guess form-source id: {@link #FORM_SOURCE_ID_NUM_BYTES} bytes from a 
   {@link SecureRandom}, encoded as lower-case hex.
  */
  private Id calcFormSourceId(){
    byte[] randomBytes = new byte[FORM_SOURCE_ID_NUM_BYTES];
    fRandom.nextBytes(randomBytes);
    return new Id(hexEncode(randomBytes));
  }
  
  /** 
   Fetch the form-source id of the user's preceding session via the DAO, and place it in the session, if present.
   @throws ServletException wrapping any DAOException raised by the database fetch.
  */
  private void addPreviousFormSourceIdToSession(HttpSession aSession, CsrfDAO aDAO) throws ServletException {
    fLogger.fine("Adding previous form-source id to session.");
    try {
      Id previousFormSourceId = aDAO.fetchPreviousFormSourceId();
      if( previousFormSourceId == null ) {
        fLogger.fine("No previous form-source id found.");
      }
      else {
        fLogger.fine("Adding previous form-source id to session.");
        aSession.setAttribute(PREVIOUS_FORM_SOURCE_ID_KEY, previousFormSourceId);
      }
    }
    catch (DAOException ex){
      throw new ServletException("Cannot fetch previous form-source id from database.", ex);
    }
  }
  
  private void  addFormSourceDAOToSession(HttpSession aSession, CsrfDAO aDAO) {
    fLogger.fine("Adding CsrfDAO object to session.");
    aSession.setAttribute(FORM_SOURCE_DAO_KEY, aDAO);
  }
  
  /**
   Return the encoding used to compute the byte length of the modified response.
   
   <P>Uses the request attribute stored under {@link Controller#CHARACTER_ENCODING}, if present. 
   Otherwise it falls back to the response's own character encoding, which is what the response's 
   writer uses. Without this fallback, a missing attribute would cause a NullPointerException.
  */
  private static String getEncoding(HttpServletRequest aRequest, HttpServletResponse aResponse){
    String result = (String)WebUtil.findAttribute(Controller.CHARACTER_ENCODING, aRequest);
    if ( ! Util.textHasContent(result) ) {
      result = aResponse.getCharacterEncoding();
    }
    return result;
  }
  
  /**
   Convert a byte array into a String of lower-case hex characters, two per byte.
   
   This implementation follows the example of David Flanagan's book
   "Java In A Nutshell".
  */
  static private String hexEncode( byte[] aInput){
    StringBuilder result = new StringBuilder(aInput.length * 2);
    char[] digits = {'0', '1', '2', '3', '4','5','6','7','8','9','a','b','c','d','e','f'};
    for (byte b : aInput) {
      result.append( digits[ (b&0xf0) >> 4 ] );
      result.append( digits[ b&0x0f] );
    }
    return result.toString();
  }

  private static final String TEXT_HTML = "text/html";
  
  /** Return true if content-type of response is null (or empty), or starts with 'text/html' (case-sensitive).  */
  private boolean isServingHtml(HttpServletResponse aResponse){
    String contentType = aResponse.getContentType();
    boolean missingContentType = ! Util.textHasContent(contentType);
    boolean startsWithHTML = Util.textHasContent(contentType) && contentType.startsWith(TEXT_HTML);
    return missingContentType || startsWithHTML;
  }
  
  /** 
   Response wrapper which captures everything written through {@link #getWriter()} into an 
   in-memory buffer, instead of sending it to the client. 
   Output written through getOutputStream() is not captured.
  */
  private static final class CharResponseWrapper extends HttpServletResponseWrapper {
    /** Return all text written so far through {@link #getWriter()}. */
    @Override public String toString() {
        return fOutput.toString();
    }
    public CharResponseWrapper(HttpServletResponse response){
        super(response);
        fOutput = new CharArrayWriter();
    }
    /** Return a writer onto the in-memory buffer. */
    @Override public PrintWriter getWriter(){
        return new PrintWriter(fOutput);
    }
    private final CharArrayWriter fOutput;
  }
}
