001    package hirondelle.web4j.security;
002    
import hirondelle.web4j.Controller;
import hirondelle.web4j.database.DAOException;
import hirondelle.web4j.database.SqlId;
import hirondelle.web4j.model.Id;
import hirondelle.web4j.util.Util;
import hirondelle.web4j.util.WebUtil;

import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Random;
import java.util.logging.Logger;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
028    
029    /**
030     Protect your application from a 
031     <a href='http://en.wikipedia.org/wiki/Cross-site_request_forgery'>Cross Site Request Forgery</a> (CSRF).
032    
033     <P>Please see the package overview for important information regarding CSRF attacks, and security in general.
034     
035     <P>This filter maintains various items needed to protect against CSRF attacks. It acts both as a 
036     pre-processor and as a post-processor. The behavior of this class is controlled by detecting two important events: 
037     <ul>
  <li>the creation of new sessions (which does <i>not</i> necessarily imply a successful user login has also occurred)
039      <li>a successful user login (which <i>does</i> imply a session has also been created)
040     </ul> 
041    
042     <h4>Pre-processing</h4>
043     When <i>a new session</i> is detected (but not necessarily a user login), then this class will do the following :
044     <ul>
045     <li>calculate a random form-source id, and place it in session scope, under the key {@link #FORM_SOURCE_ID_KEY}. 
046     This value is difficult to guess.
047     <li>wrap the response in a custom wrapper, to implement the post-processing performed by this filter (see below)
048     </ul>
049     
050     In addition, if <i>a new user login</i> is detected, then this class will do the following :
051     <ul>
052     <li>if there is any 'old' form-source id, place it in session scope as well, under the 
053     key {@link #PREVIOUS_FORM_SOURCE_ID_KEY}. The 'old' form-source id is simply the form-source id 
054     used in the <em>immediately preceding session for the same user</em>.
055     <li>place in session scope an object which will store the form-source id when the session expires or is invalidated, under 
056     the key {@link #FORM_SOURCE_DAO_KEY}.
057     </ul> 
058     
059     <P>The above behavior of this class upon user login requires interaction with your database. 
060     It's configured in <tt>web.xml</tt> using two items : 
061     <tt>FormSourceIdRead</tt> and <tt>FormSourceIdWrite</tt>. These two items are 
062     {@link hirondelle.web4j.database.SqlId} references. 
063     They tell this class which SQL statements to use when reading and writing form-source ids 
064     to the database. As usual, these {@link SqlId} items must be declared somewhere in your 
065     application as <tt>public static final</tt> fields, and the corresponding SQL statements 
066     must appear somewhere in an <tt>.sql</tt> file.
067     
068     <P>(Please see these items in the example application for an illustration : <tt>web.xml</tt>, 
069     <tt>UserDAO</tt>, and <tt>csrf.sql</tt>.) 
070     
071     <h4>Post-processing</h4>
072     If a session is present, then this class will use a custom response wrapper to alter the response:
073     <ul>
074     <li>if the response has <tt>content-type</tt> of <tt>text/html</tt> (or <tt>null</tt>), then scan 
075     the response for all {@code <FORM>} tags with <tt>method='POST'</tt>. 
076     <li>for each such {@code <FORM>} tag, add a hidden parameter in the following style :
077    <PRE>&lt;input type='hidden' name='web4j_key_for_form_source_id' value='151jdk65654dasdf545sadf6a5s4f'&gt;</PRE>
078    </ul>
079     
080     The name of the hidden parameter is taken from {@link #FORM_SOURCE_ID_KEY}, 
081     and the <tt>value</tt> of that hidden parameter is the random token created during the pre-processing stage.
082    
083    <h4>ApplicationFirewall</h4>
084    This class cooperates closely with {@link hirondelle.web4j.security.ApplicationFirewallImpl}. It is the 
085    firewall which performs the actual test to make sure the POSTed form came from your web app. 
086    
087     <h4>Warning Regarding Error Pages</h4>
088     This Filter uses a wrapper for the response. When a Filter wraps the response, the error page 
089     customization defined by <tt>web.xml</tt> will likely not function. 
 (This may be a defect of the Servlet API itself; see section 9.9.3 of the Servlet Specification.) That is, when an error occurs while using this
091     Filter, the generic error pages defined by the container may be served, instead of the custom 
092     error pages you have configured in <tt>web.xml</tt>.
093     
094     <P>This filter will only affect the response if its content-type is <tt>text/html</tt> or <tt>null</tt>.
095     It will not affect any other type of response.
096    */
097    public class CsrfFilter implements Filter {
098    
099      /** 
100       <em>Key</em> for item stored in session scope, and also <em>name</em> of hidden 
101       request parameter added to POSTed forms.
102       
103       <P>Value - {@value}.
104       <P>The <em>value</em> of this item is generated randomly for each new user login, and contains a 
105       simple token that is hard to guess. Each POSTed form will be required by {@link ApplicationFirewallImpl} 
106       to include a hidden parameter of this <em>name</em>, and the <em>value</em> of such hidden parameters 
107       are matched to the corresponding item stored in session scope under the same key. These checks verify that  
108       POSTed forms have come from a trusted source.
109      */
110      public static final String FORM_SOURCE_ID_KEY = "web4j_key_for_form_source_id";
111    
112      /** 
113       Key for item stored in session scope.
114       
115       <P>Value - {@value}.
116       <P>The value of this item is retrieved from the database for each new user login, and 
117       represents the form-source id for the user's <em>immediately preceding</em> session. 
118       When a match of form-source id against {@link #FORM_SOURCE_ID_KEY} fails, then a second 
119       match is attempted against this item.
120       
121       <P>Please see the package description for an explanation of why this is necessary.
122      */
123      public static final String PREVIOUS_FORM_SOURCE_ID_KEY = "web4j_key_for_previous_form_source_id";
124      
125      /**
126       Key for item stored in session scope.
127         
128       <P>Value - {@value}.
129       <P>This item points to an {@link javax.servlet.http.HttpSessionBindingListener} object placed in each new session. 
130       When the session ends, that object will be unbound from the session, and will save the user's current form-source id 
131       to the database, for future use.  
132      */
133      public static final String FORM_SOURCE_DAO_KEY = "web4j_key_for_form_source_dao";
134    
135      /** 
136       Read in filter configuration. 
137       
138       <P>Reads in {@link hirondelle.web4j.database.SqlId} references used to read and write the user's form-source id.
139       <P>See class comment and package-level description for further information.
140      */
141      public void init(FilterConfig aFilterConfig)  {
142        fLogger.config("INIT : " + this.getClass().getName() + ". Reading in SqlIds for reading and writing form-source ids.");
143        String read_sql = aFilterConfig.getInitParameter("FormSourceIdRead");
144        String write_sql = aFilterConfig.getInitParameter("FormSourceIdWrite");
145        checkValidSqlId(read_sql);
146        checkValidSqlId(write_sql);
147        CsrfDAO.init(read_sql, write_sql);
148      }
149      
150      /** This implementation does nothing.  */
151      public void destroy() {
152        fLogger.config("DESTROY : " + this.getClass().getName());
153      }
154      
155      /**
156       Protect against CSRF attacks.
157      
158       <P>See class comment and package-level description for further information.
159      */
160      public void doFilter(ServletRequest aRequest, ServletResponse aResponse, FilterChain aChain) throws IOException, ServletException {
161        fLogger.fine("START CSRF Filter.");
162        HttpServletRequest request = (HttpServletRequest)aRequest;
163        HttpServletResponse response = (HttpServletResponse)aResponse;
164        
165        addItemsForNewSessions(request);
166        
167        if(isServingHtml(response)){
168          fLogger.fine("Serving html. Wrapping response.");
169          CharResponseWrapper wrapper = new CharResponseWrapper(response);
170          aChain.doFilter(aRequest, wrapper); //AppFirewall and BadRequest
171          
172          CharArrayWriter buffer = new CharArrayWriter(); 
173          CsrfModifiedResponse modifiedResponse = new CsrfModifiedResponse(request, response);
174          String originalOutput = wrapper.toString();
175          buffer.write(modifiedResponse.addNonceTo(originalOutput));
176          String encoding = (String)WebUtil.findAttribute(Controller.CHARACTER_ENCODING, request);
177          aResponse.setContentLength(buffer.toString().getBytes(encoding).length);
178          
179          aResponse.getWriter().write(buffer.toString()); //this will use the response's encoding
180          aResponse.getWriter().close();
181        }
182        else {
183          fLogger.fine("Not serving html. Not modifiying response.");
184          aChain.doFilter(aRequest, aResponse); //do nothing special
185        }
186        fLogger.fine("END CSRF Filter.");
187      }
188      
189      /**
190       Add a CSRF token to an existing session <i>that has no user login</i>.
191       
192       <P><i>This method is called only when a session created by an Action, instead of the usual login mechanism.</i>
193       See {@link hirondelle.web4j.action.ActionImpl#createSessionAndCsrfToken()} for important information.  
194      */
195      public void addCsrfToken(HttpServletRequest aRequest) throws ServletException {
196        addItemsForNewSessions(aRequest);
197      }
198      
199      // PRIVATE
200      
201      //WARNING : Filters always need to be thread-safe !!
202      
203      private static final Logger fLogger = Util.getLogger(CsrfFilter.class);
204      private static final boolean DO_NOT_CREATE = false;
205      
206      private static void checkValidSqlId(String aSqlId){
207        if ( ! Util.textHasContent(aSqlId) ) {
208          String message = "SqlId required as Filter init-param, but has no content: " + Util.quote(aSqlId); 
209          fLogger.severe(message);
210        }
211      }
212      
213      private void addItemsForNewSessions(HttpServletRequest aRequest) throws ServletException {
214        HttpSession session = aRequest.getSession(DO_NOT_CREATE);
215        if ( sessionExists(session) ){
216          if ( hasNoFormSourceIdInSession(session) ){
217            Id currentFormSourceId = calcFormSourceId();
218            addFormSourceIdToSession(session, currentFormSourceId);
219            if( userHasLoggedIn(aRequest) ){
220              CsrfDAO formSourceDAO = new CsrfDAO(aRequest.getUserPrincipal().getName(), currentFormSourceId);
221              addPreviousFormSourceIdToSession(session, formSourceDAO);          
222              addFormSourceDAOToSession(session, formSourceDAO);
223            }
224          }
225        }
226      }
227      
228      private boolean sessionExists(HttpSession aSession){
229        return aSession != null;
230      }
231      
232      private boolean hasNoFormSourceIdInSession(HttpSession aSession){
233        return aSession.getAttribute(FORM_SOURCE_ID_KEY) == null;
234      }
235    
236      private boolean userHasLoggedIn(HttpServletRequest aRequest){
237        return aRequest.getUserPrincipal() != null;
238      }
239      
240      private void addFormSourceIdToSession(HttpSession aSession, Id aCurrentFormSourceId) {
241        fLogger.fine("Adding new form-source id to user's session.");
242        aSession.setAttribute(FORM_SOURCE_ID_KEY, aCurrentFormSourceId);
243      }
244      
245      private Id calcFormSourceId(){
246        String token = getHashFor( getRandomNumber().toString() );
247        return new Id(token);    
248      }
249      
250      private void addPreviousFormSourceIdToSession(HttpSession aSession, CsrfDAO aDAO) throws ServletException {
251        fLogger.fine("Adding previous form-source id to session.");
252        try {
253          Id previousFormSourceId = aDAO.fetchPreviousFormSourceId();
254          if( previousFormSourceId == null ) {
255            fLogger.fine("No previous form-source id found.");
256          }
257          else {
258            fLogger.fine("Adding previous form-source id to session.");
259            aSession.setAttribute(PREVIOUS_FORM_SOURCE_ID_KEY, previousFormSourceId);
260          }
261        }
262        catch (DAOException ex){
263          throw new ServletException("Cannot fetch previous form-source id from database.", ex);
264        }
265      }
266      
267      private void  addFormSourceDAOToSession(HttpSession aSession, CsrfDAO aDAO) {
268        fLogger.fine("Adding CsrfDAO object to session.");
269        aSession.setAttribute(FORM_SOURCE_DAO_KEY, aDAO);
270      }
271      
272      private synchronized Long getRandomNumber() {
273        Random random = new Random();
274        return random.nextLong();
275      }
276      
277      private String getHashFor(String aText) {
278        String result = null;
279        try {
280          MessageDigest sha = MessageDigest.getInstance("SHA-1");
281          byte[] hashOne = sha.digest(aText.getBytes());
282          result = hexEncode(hashOne);
283        }
284        catch (NoSuchAlgorithmException ex){
285          String message = "MessageDigest cannot find SHA-1 algorithm."; 
286          fLogger.severe(message);
287          throw new RuntimeException(message);
288        }
289        return result;
290      }
291      
292      /**
293       The byte[] returned by MessageDigest does not have a nice
294       textual representation, so some form of encoding is usually performed.
295      
296       This implementation follows the example of David Flanagan's book
297       "Java In A Nutshell", and converts a byte array into a String
298       of hex characters.
299      */
300      static private String hexEncode( byte[] aInput){
301        StringBuilder result = new StringBuilder();
302        char[] digits = {'0', '1', '2', '3', '4','5','6','7','8','9','a','b','c','d','e','f'};
303        for (int idx = 0; idx < aInput.length; ++idx) {
304          byte b = aInput[idx];
305          result.append( digits[ (b&0xf0) >> 4 ] );
306          result.append( digits[ b&0x0f] );
307        }
308        return result.toString();
309      }
310    
311      private static final String TEXT_HTML = "text/html";
312      
313      /** Return true if content-type of reponse is null, or starts with 'text/html' (case-sensitive).  */
314      private boolean isServingHtml(HttpServletResponse aResponse){
315        String contentType = aResponse.getContentType();
316        boolean missingContentType = ! Util.textHasContent(contentType);
317        boolean startsWithHTML = Util.textHasContent(contentType) && contentType.startsWith(TEXT_HTML);
318        return missingContentType || startsWithHTML;
319      }
320      
321      private static final class CharResponseWrapper extends HttpServletResponseWrapper {
322        public String toString() {
323            return fOutput.toString();
324        }
325        public CharResponseWrapper(HttpServletResponse response){
326            super(response);
327            fOutput = new CharArrayWriter();
328        }
329        public PrintWriter getWriter(){
330            return new PrintWriter(fOutput);
331        }
332        private CharArrayWriter fOutput;
333      }
334    }