Line data Source code
1 : // Copyright (C) 2015 The Android Open Source Project 2 : // 3 : // Licensed under the Apache License, Version 2.0 (the "License"); 4 : // you may not use this file except in compliance with the License. 5 : // You may obtain a copy of the License at 6 : // 7 : // http://www.apache.org/licenses/LICENSE-2.0 8 : // 9 : // Unless required by applicable law or agreed to in writing, software 10 : // distributed under the License is distributed on an "AS IS" BASIS, 11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 : // See the License for the specific language governing permissions and 13 : // limitations under the License. 14 : 15 : package com.google.gerrit.httpd.auth.oauth; 16 : 17 : import static java.nio.charset.StandardCharsets.UTF_8; 18 : 19 : import com.google.common.base.MoreObjects; 20 : import com.google.common.base.Strings; 21 : import com.google.common.collect.Iterables; 22 : import com.google.gerrit.common.Nullable; 23 : import com.google.gerrit.extensions.auth.oauth.OAuthServiceProvider; 24 : import com.google.gerrit.extensions.registration.DynamicMap; 25 : import com.google.gerrit.httpd.HtmlDomUtil; 26 : import com.google.gerrit.httpd.LoginUrlToken; 27 : import com.google.gerrit.httpd.template.SiteHeaderFooter; 28 : import com.google.gerrit.server.config.CanonicalWebUrl; 29 : import com.google.inject.Inject; 30 : import com.google.inject.Provider; 31 : import com.google.inject.Singleton; 32 : import java.io.IOException; 33 : import java.util.Map; 34 : import java.util.NavigableMap; 35 : import java.util.NavigableSet; 36 : import java.util.Set; 37 : import javax.servlet.Filter; 38 : import javax.servlet.FilterChain; 39 : import javax.servlet.FilterConfig; 40 : import javax.servlet.ServletException; 41 : import javax.servlet.ServletOutputStream; 42 : import javax.servlet.ServletRequest; 43 : import javax.servlet.ServletResponse; 44 : import javax.servlet.http.HttpServletRequest; 45 : import javax.servlet.http.HttpServletResponse; 46 : import org.w3c.dom.Document; 47 : import org.w3c.dom.Element; 48 : 49 : @Singleton 50 : /* OAuth web filter uses active OAuth session to perform OAuth requests */ 51 : class OAuthWebFilter implements Filter { 52 : static final String GERRIT_LOGIN = "/login"; 53 : 54 : private final Provider<String> urlProvider; 55 : private final Provider<OAuthSession> oauthSessionProvider; 56 : private final DynamicMap<OAuthServiceProvider> oauthServiceProviders; 57 : private final SiteHeaderFooter header; 58 : private OAuthServiceProvider ssoProvider; 59 : 60 : @Inject 61 : OAuthWebFilter( 62 : @CanonicalWebUrl @Nullable Provider<String> urlProvider, 63 : DynamicMap<OAuthServiceProvider> oauthServiceProviders, 64 : Provider<OAuthSession> oauthSessionProvider, 65 0 : SiteHeaderFooter header) { 66 0 : this.urlProvider = urlProvider; 67 0 : this.oauthServiceProviders = oauthServiceProviders; 68 0 : this.oauthSessionProvider = oauthSessionProvider; 69 0 : this.header = header; 70 0 : } 71 : 72 : @Override 73 : public void init(FilterConfig filterConfig) throws ServletException { 74 0 : pickSSOServiceProvider(); 75 0 : } 76 : 77 : @Override 78 0 : public void destroy() {} 79 : 80 : @Override 81 : public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) 82 : throws IOException, ServletException { 83 0 : HttpServletRequest httpRequest = (HttpServletRequest) request; 84 0 : HttpServletResponse httpResponse = (HttpServletResponse) response; 85 : 86 0 : OAuthSession oauthSession = oauthSessionProvider.get(); 87 0 : if (request.getParameter("link") != null) { 88 0 : oauthSession.setLinkMode(true); 89 0 : oauthSession.setServiceProvider(null); 90 : } 91 : 92 0 : String provider = httpRequest.getParameter("provider"); 93 : OAuthServiceProvider service = 94 0 : ssoProvider == null ? oauthSession.getServiceProvider() : ssoProvider; 95 : 96 0 : if (isGerritLogin(httpRequest) || oauthSession.isOAuthFinal(httpRequest)) { 97 0 : if (service == null && Strings.isNullOrEmpty(provider)) { 98 0 : selectProvider(httpRequest, httpResponse, null); 99 0 : return; 100 : } 101 0 : if (service == null) { 102 0 : service = findService(provider); 103 : } 104 0 : oauthSession.setServiceProvider(service); 105 0 : oauthSession.login(httpRequest, httpResponse, service); 106 : } else { 107 0 : chain.doFilter(httpRequest, response); 108 : } 109 0 : } 110 : 111 : private OAuthServiceProvider findService(String providerId) throws ServletException { 112 0 : Set<String> plugins = oauthServiceProviders.plugins(); 113 0 : for (String pluginName : plugins) { 114 0 : Map<String, Provider<OAuthServiceProvider>> m = oauthServiceProviders.byPlugin(pluginName); 115 0 : for (Map.Entry<String, Provider<OAuthServiceProvider>> e : m.entrySet()) { 116 0 : if (providerId.equals(String.format("%s_%s", pluginName, e.getKey()))) { 117 0 : return e.getValue().get(); 118 : } 119 0 : } 120 0 : } 121 0 : throw new ServletException("No provider found for: " + providerId); 122 : } 123 : 124 : private void selectProvider( 125 : HttpServletRequest req, HttpServletResponse res, @Nullable String errorMessage) 126 : throws IOException { 127 0 : String self = req.getRequestURI(); 128 0 : String cancel = MoreObjects.firstNonNull(urlProvider != null ? urlProvider.get() : "/", "/"); 129 0 : cancel += LoginUrlToken.getToken(req); 130 : 131 0 : Document doc = header.parse(OAuthWebFilter.class, "LoginForm.html"); 132 0 : HtmlDomUtil.find(doc, "hostName").setTextContent(req.getServerName()); 133 0 : HtmlDomUtil.find(doc, "login_form").setAttribute("action", self); 134 0 : HtmlDomUtil.find(doc, "cancel_link").setAttribute("href", cancel); 135 : 136 0 : Element emsg = HtmlDomUtil.find(doc, "error_message"); 137 0 : if (Strings.isNullOrEmpty(errorMessage)) { 138 0 : emsg.getParentNode().removeChild(emsg); 139 : } else { 140 0 : emsg.setTextContent(errorMessage); 141 : } 142 : 143 0 : Element providers = HtmlDomUtil.find(doc, "providers"); 144 : 145 0 : Set<String> plugins = oauthServiceProviders.plugins(); 146 0 : for (String pluginName : plugins) { 147 0 : Map<String, Provider<OAuthServiceProvider>> m = oauthServiceProviders.byPlugin(pluginName); 148 0 : for (Map.Entry<String, Provider<OAuthServiceProvider>> e : m.entrySet()) { 149 0 : addProvider(providers, pluginName, e.getKey(), e.getValue().get().getName()); 150 0 : } 151 0 : } 152 : 153 0 : sendHtml(res, doc); 154 0 : } 155 : 156 : private static void addProvider(Element form, String pluginName, String id, String serviceName) { 157 0 : Element div = form.getOwnerDocument().createElement("div"); 158 0 : div.setAttribute("id", id); 159 0 : Element hyperlink = form.getOwnerDocument().createElement("a"); 160 0 : hyperlink.setAttribute("href", String.format("?provider=%s_%s", pluginName, id)); 161 0 : hyperlink.setTextContent(serviceName + " (" + pluginName + " plugin)"); 162 0 : div.appendChild(hyperlink); 163 0 : form.appendChild(div); 164 0 : } 165 : 166 : private static void sendHtml(HttpServletResponse res, Document doc) throws IOException { 167 0 : byte[] bin = HtmlDomUtil.toUTF8(doc); 168 0 : res.setStatus(HttpServletResponse.SC_UNAUTHORIZED); 169 0 : res.setContentType("text/html"); 170 0 : res.setCharacterEncoding(UTF_8.name()); 171 0 : res.setContentLength(bin.length); 172 0 : try (ServletOutputStream out = res.getOutputStream()) { 173 0 : out.write(bin); 174 : } 175 0 : } 176 : 177 : private void pickSSOServiceProvider() throws ServletException { 178 0 : NavigableSet<String> plugins = oauthServiceProviders.plugins(); 179 0 : if (plugins.isEmpty()) { 180 0 : throw new ServletException("OAuth service provider wasn't installed"); 181 : } 182 0 : if (plugins.size() == 1) { 183 0 : NavigableMap<String, Provider<OAuthServiceProvider>> services = 184 0 : oauthServiceProviders.byPlugin(Iterables.getOnlyElement(plugins)); 185 0 : if (services.size() == 1) { 186 0 : ssoProvider = Iterables.getOnlyElement(services.values()).get(); 187 : } 188 : } 189 0 : } 190 : 191 : private static boolean isGerritLogin(HttpServletRequest request) { 192 0 : return request.getRequestURI().contains(GERRIT_LOGIN); 193 : } 194 : }