Major Enhancements to SAML2 and OAuth2 Integration with Simplified Security Configurations (#2040)

* implement Saml2 login/logout

* changed: deprecation code

* relyingPartyRegistrations only enabled samle
This commit is contained in:
Ludy 2024-10-20 13:30:58 +02:00 committed by GitHub
parent 227d18a469
commit eff1843061
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 1080 additions and 839 deletions

View file

@ -32,11 +32,9 @@ java {
repositories { repositories {
mavenCentral() mavenCentral()
maven { url "https://jitpack.io" } maven { url "https://jitpack.io" }
maven { url "https://build.shibboleth.net/nexus/content/repositories/releases/" }
maven { maven {
url "https://build.shibboleth.net/nexus/content/repositories/releases/" url 'https://build.shibboleth.net/maven/releases'
}
maven {
url "https://build.shibboleth.net/maven/releases/"
} }
} }
@ -148,6 +146,14 @@ dependencies {
//2.2.x requires rebuild of DB file.. need migration path //2.2.x requires rebuild of DB file.. need migration path
runtimeOnly "com.h2database:h2:2.1.214" runtimeOnly "com.h2database:h2:2.1.214"
// implementation "com.h2database:h2:2.2.224" // implementation "com.h2database:h2:2.2.224"
constraints {
implementation "org.opensaml:opensaml-core"
implementation "org.opensaml:opensaml-saml-api"
implementation "org.opensaml:opensaml-saml-impl"
}
implementation "org.springframework.security:spring-security-saml2-service-provider"
implementation 'com.coveo:saml-client:5.0.0'
} }
testImplementation "org.springframework.boot:spring-boot-starter-test:$springBootVersion" testImplementation "org.springframework.boot:spring-boot-starter-test:$springBootVersion"

View file

@ -33,8 +33,12 @@ public class SPdfApplication {
@Autowired private Environment env; @Autowired private Environment env;
@Autowired ApplicationProperties applicationProperties; @Autowired ApplicationProperties applicationProperties;
private static String baseUrlStatic;
private static String serverPortStatic; private static String serverPortStatic;
@Value("${baseUrl:http://localhost}")
private String baseUrl;
@Value("${server.port:8080}") @Value("${server.port:8080}")
public void setServerPortStatic(String port) { public void setServerPortStatic(String port) {
if ("auto".equalsIgnoreCase(port)) { if ("auto".equalsIgnoreCase(port)) {
@ -65,12 +69,13 @@ public class SPdfApplication {
@PostConstruct @PostConstruct
public void init() { public void init() {
baseUrlStatic = this.baseUrl;
// Check if the BROWSER_OPEN environment variable is set to true // Check if the BROWSER_OPEN environment variable is set to true
String browserOpenEnv = env.getProperty("BROWSER_OPEN"); String browserOpenEnv = env.getProperty("BROWSER_OPEN");
boolean browserOpen = browserOpenEnv != null && "true".equalsIgnoreCase(browserOpenEnv); boolean browserOpen = browserOpenEnv != null && "true".equalsIgnoreCase(browserOpenEnv);
if (browserOpen) { if (browserOpen) {
try { try {
String url = "http://localhost:" + getStaticPort(); String url = baseUrl + ":" + getStaticPort();
String os = System.getProperty("os.name").toLowerCase(); String os = System.getProperty("os.name").toLowerCase();
Runtime rt = Runtime.getRuntime(); Runtime rt = Runtime.getRuntime();
@ -138,10 +143,18 @@ public class SPdfApplication {
private static void printStartupLogs() { private static void printStartupLogs() {
logger.info("Stirling-PDF Started."); logger.info("Stirling-PDF Started.");
String url = "http://localhost:" + getStaticPort(); String url = baseUrlStatic + ":" + getStaticPort();
logger.info("Navigate to {}", url); logger.info("Navigate to {}", url);
} }
public static String getStaticBaseUrl() {
return baseUrlStatic;
}
public String getNonStaticBaseUrl() {
return baseUrlStatic;
}
public static String getStaticPort() { public static String getStaticPort() {
return serverPortStatic; return serverPortStatic;
} }

View file

@ -1,27 +1,237 @@
package stirling.software.SPDF.config.security; package stirling.software.SPDF.config.security;
import java.io.IOException; import java.io.IOException;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.util.ArrayList;
import java.util.List;
import org.springframework.core.io.Resource;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler; import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import com.coveo.saml.SamlClient;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.SPdfApplication;
import stirling.software.SPDF.config.security.saml2.CertificateUtils;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.Provider;
import stirling.software.SPDF.model.provider.UnsupportedProviderException;
import stirling.software.SPDF.utils.UrlUtils;
@Slf4j
@AllArgsConstructor
public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
private final ApplicationProperties applicationProperties;
@Override @Override
public void onLogoutSuccess( public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication) HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException { throws IOException, ServletException {
if (!response.isCommitted()) {
// Handle user logout due to disabled account
if (request.getParameter("userIsDisabled") != null) { if (request.getParameter("userIsDisabled") != null) {
getRedirectStrategy() response.sendRedirect(
.sendRedirect(request, response, "/login?erroroauth=userIsDisabled"); request.getContextPath() + "/login?erroroauth=userIsDisabled");
return; return;
} }
// Handle OAuth2 authentication error
if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
response.sendRedirect(
request.getContextPath() + "/login?erroroauth=userAlreadyExistsWeb");
return;
}
if (authentication != null) {
// Handle SAML2 logout redirection
if (authentication instanceof Saml2Authentication) {
getRedirect_saml2(request, response, authentication);
return;
}
// Handle OAuth2 logout redirection
else if (authentication instanceof OAuth2AuthenticationToken) {
getRedirect_oauth2(request, response, authentication);
return;
}
// Handle Username/Password logout
else if (authentication instanceof UsernamePasswordAuthenticationToken) {
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
return;
}
// Handle unknown authentication types
else {
log.error(
"authentication class unknown: "
+ authentication.getClass().getSimpleName());
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
return;
}
} else {
// Redirect to login page after logout
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
return;
}
}
}
// Redirect for SAML2 authentication logout
private void getRedirect_saml2(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
SAML2 samlConf = applicationProperties.getSecurity().getSaml2();
String registrationId = samlConf.getRegistrationId();
Saml2Authentication samlAuthentication = (Saml2Authentication) authentication;
CustomSaml2AuthenticatedPrincipal principal =
(CustomSaml2AuthenticatedPrincipal) samlAuthentication.getPrincipal();
String nameIdValue = principal.getName();
try {
// Read certificate from the resource
Resource certificateResource = samlConf.getSpCert();
X509Certificate certificate = CertificateUtils.readCertificate(certificateResource);
List<X509Certificate> certificates = new ArrayList<>();
certificates.add(certificate);
// Construct URLs required for SAML configuration
String serverUrl =
SPdfApplication.getStaticBaseUrl() + ":" + SPdfApplication.getStaticPort();
String relyingPartyIdentifier =
serverUrl + "/saml2/service-provider-metadata/" + registrationId;
String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
String idpUrl = samlConf.getIdpSingleLogoutUrl();
String idpIssuer = samlConf.getIdpIssuer();
// Create SamlClient instance for SAML logout
SamlClient samlClient =
new SamlClient(
relyingPartyIdentifier,
assertionConsumerServiceUrl,
idpUrl,
idpIssuer,
certificates,
SamlClient.SamlIdpBinding.POST);
// Read private key for service provider
Resource privateKeyResource = samlConf.getPrivateKey();
RSAPrivateKey privateKey = CertificateUtils.readPrivateKey(privateKeyResource);
// Set service provider keys for the SamlClient
samlClient.setSPKeys(certificate, privateKey);
// Redirect to identity provider for logout
samlClient.redirectToIdentityProvider(response, null, nameIdValue);
} catch (Exception e) {
log.error(nameIdValue, e);
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true"); getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
} }
} }
// Redirect for OAuth2 authentication logout
private void getRedirect_oauth2(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
String param = "logout=true";
String registrationId = null;
String issuer = null;
String clientId = null;
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (authentication instanceof OAuth2AuthenticationToken) {
OAuth2AuthenticationToken oauthToken = (OAuth2AuthenticationToken) authentication;
registrationId = oauthToken.getAuthorizedClientRegistrationId();
try {
// Get OAuth2 provider details from configuration
Provider provider = oauth.getClient().get(registrationId);
issuer = provider.getIssuer();
clientId = provider.getClientId();
} catch (UnsupportedProviderException e) {
log.error(e.getMessage());
}
} else {
registrationId = oauth.getProvider() != null ? oauth.getProvider() : "";
issuer = oauth.getIssuer();
clientId = oauth.getClientId();
}
String errorMessage = "";
// Handle different error scenarios during logout
if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
param = "erroroauth=oauth2AuthenticationErrorWeb";
} else if ((errorMessage = request.getParameter("error")) != null) {
param = "error=" + sanitizeInput(errorMessage);
} else if ((errorMessage = request.getParameter("erroroauth")) != null) {
param = "erroroauth=" + sanitizeInput(errorMessage);
} else if (request.getParameter("oauth2AutoCreateDisabled") != null) {
param = "error=oauth2AutoCreateDisabled";
} else if (request.getParameter("oauth2_admin_blocked_user") != null) {
param = "erroroauth=oauth2_admin_blocked_user";
} else if (request.getParameter("userIsDisabled") != null) {
param = "erroroauth=userIsDisabled";
} else if (request.getParameter("badcredentials") != null) {
param = "error=badcredentials";
}
String redirect_url = UrlUtils.getOrigin(request) + "/login?" + param;
// Redirect based on OAuth2 provider
switch (registrationId.toLowerCase()) {
case "keycloak":
// Add Keycloak specific logout URL if needed
String logoutUrl =
issuer
+ "/protocol/openid-connect/logout"
+ "?client_id="
+ clientId
+ "&post_logout_redirect_uri="
+ response.encodeRedirectURL(redirect_url);
log.info("Redirecting to Keycloak logout URL: " + logoutUrl);
response.sendRedirect(logoutUrl);
break;
case "github":
// Add GitHub specific logout URL if needed
String githubLogoutUrl = "https://github.com/logout";
log.info("Redirecting to GitHub logout URL: " + githubLogoutUrl);
response.sendRedirect(githubLogoutUrl);
break;
case "google":
// Add Google specific logout URL if needed
// String googleLogoutUrl =
// "https://accounts.google.com/Logout?continue=https://appengine.google.com/_ah/logout?continue="
// + response.encodeRedirectURL(redirect_url);
log.info("Google does not have a specific logout URL");
// log.info("Redirecting to Google logout URL: " + googleLogoutUrl);
// response.sendRedirect(googleLogoutUrl);
// break;
default:
String defaultRedirectUrl = request.getContextPath() + "/login?" + param;
log.info("Redirecting to default logout URL: " + defaultRedirectUrl);
response.sendRedirect(defaultRedirectUrl);
break;
}
}
// Sanitize input to avoid potential security vulnerabilities
private String sanitizeInput(String input) {
return input.replaceAll("[^a-zA-Z0-9 ]", "");
}
}

View file

@ -1,22 +1,36 @@
package stirling.software.SPDF.config.security; package stirling.software.SPDF.config.security;
import java.security.cert.X509Certificate;
import java.util.*;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.core.io.Resource;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ClientRegistrations;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.SecurityFilterChain;
@ -28,13 +42,20 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationFailureHandler; import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationFailureHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationSuccessHandler; import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationSuccessHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2LogoutSuccessHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2UserService; import stirling.software.SPDF.config.security.oauth2.CustomOAuth2UserService;
import stirling.software.SPDF.config.security.saml.ConvertResponseToAuthentication; import stirling.software.SPDF.config.security.saml2.CertificateUtils;
import stirling.software.SPDF.config.security.saml.CustomSAMLAuthenticationFailureHandler; import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticationFailureHandler;
import stirling.software.SPDF.config.security.saml.CustomSAMLAuthenticationSuccessHandler; import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticationSuccessHandler;
import stirling.software.SPDF.config.security.saml2.CustomSaml2ResponseAuthenticationConverter;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry; import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.ApplicationProperties; import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2.Client;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.User;
import stirling.software.SPDF.model.provider.GithubProvider;
import stirling.software.SPDF.model.provider.GoogleProvider;
import stirling.software.SPDF.model.provider.KeycloakProvider;
import stirling.software.SPDF.repository.JPATokenRepositoryImpl; import stirling.software.SPDF.repository.JPATokenRepositoryImpl;
@Configuration @Configuration
@ -45,12 +66,6 @@ public class SecurityConfiguration {
@Autowired private CustomUserDetailsService userDetailsService; @Autowired private CustomUserDetailsService userDetailsService;
@Autowired(required = false)
private GrantedAuthoritiesMapper userAuthoritiesMapper;
@Autowired(required = false)
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
@Bean @Bean
public PasswordEncoder passwordEncoder() { public PasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder(); return new BCryptPasswordEncoder();
@ -71,11 +86,8 @@ public class SecurityConfiguration {
@Autowired private FirstLoginFilter firstLoginFilter; @Autowired private FirstLoginFilter firstLoginFilter;
@Autowired private SessionPersistentRegistry sessionRegistry; @Autowired private SessionPersistentRegistry sessionRegistry;
@Autowired private ConvertResponseToAuthentication convertResponseToAuthentication;
@Bean @Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
http.authenticationManager(authenticationManager(http));
if (loginEnabledValue) { if (loginEnabledValue) {
http.addFilterBefore( http.addFilterBefore(
@ -94,34 +106,23 @@ public class SecurityConfiguration {
.sessionRegistry(sessionRegistry) .sessionRegistry(sessionRegistry)
.expiredUrl("/login?logout=true")); .expiredUrl("/login?logout=true"));
http.formLogin( http.authenticationProvider(daoAuthenticationProvider());
formLogin -> http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache()));
formLogin http.logout(
.loginPage("/login")
.successHandler(
new CustomAuthenticationSuccessHandler(
loginAttemptService, userService))
.defaultSuccessUrl("/")
.failureHandler(
new CustomAuthenticationFailureHandler(
loginAttemptService, userService))
.permitAll())
.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache()))
.logout(
logout -> logout ->
logout.logoutRequestMatcher( logout.logoutRequestMatcher(new AntPathRequestMatcher("/logout"))
new AntPathRequestMatcher("/logout")) .logoutSuccessHandler(
.logoutSuccessHandler(new CustomLogoutSuccessHandler()) new CustomLogoutSuccessHandler(applicationProperties))
.invalidateHttpSession(true) // Invalidate session .invalidateHttpSession(true) // Invalidate session
.deleteCookies("JSESSIONID", "remember-me")) .deleteCookies("JSESSIONID", "remember-me"));
.rememberMe( http.rememberMe(
rememberMeConfigurer -> rememberMeConfigurer ->
rememberMeConfigurer // Use the configurator directly rememberMeConfigurer // Use the configurator directly
.key("uniqueAndSecret") .key("uniqueAndSecret")
.tokenRepository(persistentTokenRepository()) .tokenRepository(persistentTokenRepository())
.tokenValiditySeconds(1209600) // 2 weeks .tokenValiditySeconds(1209600) // 2 weeks
) );
.authorizeHttpRequests( http.authorizeHttpRequests(
authz -> authz ->
authz.requestMatchers( authz.requestMatchers(
req -> { req -> {
@ -132,16 +133,14 @@ public class SecurityConfiguration {
String trimmedUri = String trimmedUri =
uri.startsWith(contextPath) uri.startsWith(contextPath)
? uri.substring( ? uri.substring(
contextPath contextPath.length())
.length())
: uri; : uri;
return trimmedUri.startsWith("/login") return trimmedUri.startsWith("/login")
|| trimmedUri.startsWith("/oauth") || trimmedUri.startsWith("/oauth")
|| trimmedUri.startsWith("/saml2") || trimmedUri.startsWith("/saml2")
|| trimmedUri.endsWith(".svg") || trimmedUri.endsWith(".svg")
|| trimmedUri.startsWith( || trimmedUri.startsWith("/register")
"/register")
|| trimmedUri.startsWith("/error") || trimmedUri.startsWith("/error")
|| trimmedUri.startsWith("/images/") || trimmedUri.startsWith("/images/")
|| trimmedUri.startsWith("/public/") || trimmedUri.startsWith("/public/")
@ -155,13 +154,24 @@ public class SecurityConfiguration {
.anyRequest() .anyRequest()
.authenticated()); .authenticated());
// Handle User/Password Logins
if (applicationProperties.getSecurity().isUserPass()) {
http.formLogin(
formLogin ->
formLogin
.loginPage("/login")
.successHandler(
new CustomAuthenticationSuccessHandler(
loginAttemptService, userService))
.failureHandler(
new CustomAuthenticationFailureHandler(
loginAttemptService, userService))
.defaultSuccessUrl("/")
.permitAll());
}
// Handle OAUTH2 Logins // Handle OAUTH2 Logins
if (applicationProperties.getSecurity().getOauth2() != null if (applicationProperties.getSecurity().isOauth2Activ()) {
&& applicationProperties.getSecurity().getOauth2().getEnabled()
&& !applicationProperties
.getSecurity()
.getLoginMethod()
.equalsIgnoreCase("normal")) {
http.oauth2Login( http.oauth2Login(
oauth2 -> oauth2 ->
@ -188,34 +198,24 @@ public class SecurityConfiguration {
userService, userService,
loginAttemptService)) loginAttemptService))
.userAuthoritiesMapper( .userAuthoritiesMapper(
userAuthoritiesMapper))) userAuthoritiesMapper()))
.logout( .permitAll());
logout ->
logout.logoutSuccessHandler(
new CustomOAuth2LogoutSuccessHandler(
applicationProperties)));
} }
// Handle SAML // Handle SAML
if (applicationProperties.getSecurity().getSaml() != null if (applicationProperties.getSecurity().isSaml2Activ()) {
&& applicationProperties.getSecurity().getSaml().getEnabled() http.authenticationProvider(samlAuthenticationProvider());
&& !applicationProperties
.getSecurity()
.getLoginMethod()
.equalsIgnoreCase("normal")) {
http.saml2Login( http.saml2Login(
saml2 -> { saml2 ->
saml2.loginPage("/saml2") saml2.loginPage("/saml2")
.relyingPartyRegistrationRepository(
relyingPartyRegistrationRepository)
.successHandler( .successHandler(
new CustomSAMLAuthenticationSuccessHandler( new CustomSaml2AuthenticationSuccessHandler(
loginAttemptService, loginAttemptService,
userService, applicationProperties,
applicationProperties)) userService))
.failureHandler( .failureHandler(
new CustomSAMLAuthenticationFailureHandler()); new CustomSaml2AuthenticationFailureHandler())
}) .permitAll())
.addFilterBefore( .addFilterBefore(
userAuthenticationFilter, Saml2WebSsoAuthenticationFilter.class); userAuthenticationFilter, Saml2WebSsoAuthenticationFilter.class);
} }
@ -231,39 +231,234 @@ public class SecurityConfiguration {
@Bean @Bean
@ConditionalOnProperty( @ConditionalOnProperty(
name = "security.saml.enabled", name = "security.saml2.enabled",
havingValue = "true", havingValue = "true",
matchIfMissing = false) matchIfMissing = false)
public AuthenticationProvider samlAuthenticationProvider() { public AuthenticationProvider samlAuthenticationProvider() {
OpenSaml4AuthenticationProvider authenticationProvider = OpenSaml4AuthenticationProvider authenticationProvider =
new OpenSaml4AuthenticationProvider(); new OpenSaml4AuthenticationProvider();
authenticationProvider.setResponseAuthenticationConverter(convertResponseToAuthentication); authenticationProvider.setResponseAuthenticationConverter(
new CustomSaml2ResponseAuthenticationConverter(userService));
return authenticationProvider; return authenticationProvider;
} }
// @Bean // Client Registration Repository for OAUTH2 OIDC Login
// public AuthenticationProvider daoAuthenticationProvider() { @Bean
// DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); @ConditionalOnProperty(
// provider.setUserDetailsService(userDetailsService); // UserDetailsService value = "security.oauth2.enabled",
// provider.setPasswordEncoder(passwordEncoder()); // PasswordEncoder havingValue = "true",
// return provider; matchIfMissing = false)
// } public ClientRegistrationRepository clientRegistrationRepository() {
List<ClientRegistration> registrations = new ArrayList<>();
githubClientRegistration().ifPresent(registrations::add);
oidcClientRegistration().ifPresent(registrations::add);
googleClientRegistration().ifPresent(registrations::add);
keycloakClientRegistration().ifPresent(registrations::add);
if (registrations.isEmpty()) {
log.error("At least one OAuth2 provider must be configured");
System.exit(1);
}
return new InMemoryClientRegistrationRepository(registrations);
}
private Optional<ClientRegistration> googleClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null || !oauth.getEnabled()) {
return Optional.empty();
}
Client client = oauth.getClient();
if (client == null) {
return Optional.empty();
}
GoogleProvider google = client.getGoogle();
return google != null && google.isSettingsValid()
? Optional.of(
ClientRegistration.withRegistrationId(google.getName())
.clientId(google.getClientId())
.clientSecret(google.getClientSecret())
.scope(google.getScopes())
.authorizationUri(google.getAuthorizationuri())
.tokenUri(google.getTokenuri())
.userInfoUri(google.getUserinfouri())
.userNameAttributeName(google.getUseAsUsername())
.clientName(google.getClientName())
.redirectUri("{baseUrl}/login/oauth2/code/" + google.getName())
.authorizationGrantType(
org.springframework.security.oauth2.core
.AuthorizationGrantType.AUTHORIZATION_CODE)
.build())
: Optional.empty();
}
private Optional<ClientRegistration> keycloakClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null || !oauth.getEnabled()) {
return Optional.empty();
}
Client client = oauth.getClient();
if (client == null) {
return Optional.empty();
}
KeycloakProvider keycloak = client.getKeycloak();
return keycloak != null && keycloak.isSettingsValid()
? Optional.of(
ClientRegistrations.fromIssuerLocation(keycloak.getIssuer())
.registrationId(keycloak.getName())
.clientId(keycloak.getClientId())
.clientSecret(keycloak.getClientSecret())
.scope(keycloak.getScopes())
.userNameAttributeName(keycloak.getUseAsUsername())
.clientName(keycloak.getClientName())
.build())
: Optional.empty();
}
private Optional<ClientRegistration> githubClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null || !oauth.getEnabled()) {
return Optional.empty();
}
Client client = oauth.getClient();
if (client == null) {
return Optional.empty();
}
GithubProvider github = client.getGithub();
return github != null && github.isSettingsValid()
? Optional.of(
ClientRegistration.withRegistrationId(github.getName())
.clientId(github.getClientId())
.clientSecret(github.getClientSecret())
.scope(github.getScopes())
.authorizationUri(github.getAuthorizationuri())
.tokenUri(github.getTokenuri())
.userInfoUri(github.getUserinfouri())
.userNameAttributeName(github.getUseAsUsername())
.clientName(github.getClientName())
.redirectUri("{baseUrl}/login/oauth2/code/" + github.getName())
.authorizationGrantType(
org.springframework.security.oauth2.core
.AuthorizationGrantType.AUTHORIZATION_CODE)
.build())
: Optional.empty();
}
private Optional<ClientRegistration> oidcClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null
|| oauth.getIssuer() == null
|| oauth.getIssuer().isEmpty()
|| oauth.getClientId() == null
|| oauth.getClientId().isEmpty()
|| oauth.getClientSecret() == null
|| oauth.getClientSecret().isEmpty()
|| oauth.getScopes() == null
|| oauth.getScopes().isEmpty()
|| oauth.getUseAsUsername() == null
|| oauth.getUseAsUsername().isEmpty()) {
return Optional.empty();
}
return Optional.of(
ClientRegistrations.fromIssuerLocation(oauth.getIssuer())
.registrationId("oidc")
.clientId(oauth.getClientId())
.clientSecret(oauth.getClientSecret())
.scope(oauth.getScopes())
.userNameAttributeName(oauth.getUseAsUsername())
.clientName("OIDC")
.build());
}
@Bean @Bean
public AuthenticationManager authenticationManager(HttpSecurity http) throws Exception { @ConditionalOnProperty(
AuthenticationManagerBuilder authenticationManagerBuilder = name = "security.saml2.enabled",
http.getSharedObject(AuthenticationManagerBuilder.class); havingValue = "true",
matchIfMissing = false)
public RelyingPartyRegistrationRepository relyingPartyRegistrations() throws Exception {
// authenticationManagerBuilder = SAML2 samlConf = applicationProperties.getSecurity().getSaml2();
// authenticationManagerBuilder.authenticationProvider(
// daoAuthenticationProvider()); // Benutzername/Passwort
if (applicationProperties.getSecurity().getSaml() != null Resource privateKeyResource = samlConf.getPrivateKey();
&& applicationProperties.getSecurity().getSaml().getEnabled()) {
authenticationManagerBuilder.authenticationProvider( Resource certificateResource = samlConf.getSpCert();
samlAuthenticationProvider()); // SAML
Saml2X509Credential signingCredential =
new Saml2X509Credential(
CertificateUtils.readPrivateKey(privateKeyResource),
CertificateUtils.readCertificate(certificateResource),
Saml2X509CredentialType.SIGNING);
X509Certificate idpCert = CertificateUtils.readCertificate(samlConf.getidpCert());
Saml2X509Credential verificationCredential = Saml2X509Credential.verification(idpCert);
RelyingPartyRegistration rp =
RelyingPartyRegistration.withRegistrationId(samlConf.getRegistrationId())
.signingX509Credentials((c) -> c.add(signingCredential))
.assertingPartyDetails(
(details) ->
details.entityId(samlConf.getIdpIssuer())
.singleSignOnServiceLocation(
samlConf.getIdpSingleLoginUrl())
.verificationX509Credentials(
(c) -> c.add(verificationCredential))
.wantAuthnRequestsSigned(true))
.build();
return new InMemoryRelyingPartyRegistrationRepository(rp);
} }
return authenticationManagerBuilder.build();
@Bean
public DaoAuthenticationProvider daoAuthenticationProvider() {
DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
provider.setUserDetailsService(userDetailsService);
provider.setPasswordEncoder(passwordEncoder());
return provider;
}
/*
This following function is to grant Authorities to the OAUTH2 user from the values stored in the database.
This is required for the internal; 'hasRole()' function to give out the correct role.
*/
@Bean
@ConditionalOnProperty(
value = "security.oauth2.enabled",
havingValue = "true",
matchIfMissing = false)
GrantedAuthoritiesMapper userAuthoritiesMapper() {
return (authorities) -> {
Set<GrantedAuthority> mappedAuthorities = new HashSet<>();
authorities.forEach(
authority -> {
// Add existing OAUTH2 Authorities
mappedAuthorities.add(new SimpleGrantedAuthority(authority.getAuthority()));
// Add Authorities from database for existing user, if user is present.
if (authority instanceof OAuth2UserAuthority oauth2Auth) {
String useAsUsername =
applicationProperties
.getSecurity()
.getOauth2()
.getUseAsUsername();
Optional<User> userOpt =
userService.findByUsernameIgnoreCase(
(String) oauth2Auth.getAttributes().get(useAsUsername));
if (userOpt.isPresent()) {
User user = userOpt.get();
if (user != null) {
mappedAuthorities.add(
new SimpleGrantedAuthority(
userService.findRole(user).getAuthority()));
}
}
}
});
return mappedAuthorities;
};
} }
@Bean @Bean

View file

@ -22,6 +22,7 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry; import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.ApiKeyAuthenticationToken; import stirling.software.SPDF.model.ApiKeyAuthenticationToken;
import stirling.software.SPDF.model.User; import stirling.software.SPDF.model.User;
@ -111,7 +112,9 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
response.setStatus(HttpStatus.UNAUTHORIZED.value()); response.setStatus(HttpStatus.UNAUTHORIZED.value());
response.getWriter() response.getWriter()
.write( .write(
"Authentication required. Please provide a X-API-KEY in request header.\nThis is found in Settings -> Account Settings -> API Key\nAlternatively you can disable authentication if this is unexpected"); "Authentication required. Please provide a X-API-KEY in request header.\n"
+ "This is found in Settings -> Account Settings -> API Key\n"
+ "Alternatively you can disable authentication if this is unexpected");
return; return;
} }
} }
@ -124,6 +127,8 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
username = ((UserDetails) principal).getUsername(); username = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) { } else if (principal instanceof OAuth2User) {
username = ((OAuth2User) principal).getName(); username = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) { } else if (principal instanceof String) {
username = (String) principal; username = (String) principal;
} }

View file

@ -20,6 +20,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import stirling.software.SPDF.config.interfaces.DatabaseBackupInterface; import stirling.software.SPDF.config.interfaces.DatabaseBackupInterface;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry; import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.controller.api.pipeline.UserServiceInterface; import stirling.software.SPDF.controller.api.pipeline.UserServiceInterface;
import stirling.software.SPDF.model.AuthenticationType; import stirling.software.SPDF.model.AuthenticationType;
@ -338,6 +339,10 @@ public class UserService implements UserServiceInterface {
} else if (principal instanceof OAuth2User) { } else if (principal instanceof OAuth2User) {
OAuth2User oAuth2User = (OAuth2User) principal; OAuth2User oAuth2User = (OAuth2User) principal;
usernameP = oAuth2User.getName(); usernameP = oAuth2User.getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
CustomSaml2AuthenticatedPrincipal saml2User =
(CustomSaml2AuthenticatedPrincipal) principal;
usernameP = saml2User.getName();
} else if (principal instanceof String) { } else if (principal instanceof String) {
usernameP = (String) principal; usernameP = (String) principal;
} }

View file

@ -51,8 +51,7 @@ public class CustomOAuth2AuthenticationFailureHandler
} }
log.error("OAuth2 Authentication error: " + errorCode); log.error("OAuth2 Authentication error: " + errorCode);
log.error("OAuth2AuthenticationException", exception); log.error("OAuth2AuthenticationException", exception);
getRedirectStrategy() getRedirectStrategy().sendRedirect(request, response, "/login?erroroauth=" + errorCode);
.sendRedirect(request, response, "/logout?erroroauth=" + errorCode);
return; return;
} }
log.error("Unhandled authentication exception", exception); log.error("Unhandled authentication exception", exception);

View file

@ -75,6 +75,11 @@ public class CustomOAuth2AuthenticationSuccessHandler
throw new LockedException( throw new LockedException(
"Your account has been locked due to too many failed login attempts."); "Your account has been locked due to too many failed login attempts.");
} }
if (userService.isUserDisabled(username)) {
getRedirectStrategy()
.sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
if (userService.usernameExistsIgnoreCase(username) if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username) && userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername( && !userService.isAuthenticationTypeByUsername(

View file

@ -1,122 +0,0 @@
package stirling.software.SPDF.config.security.oauth2;
import java.io.IOException;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.Provider;
import stirling.software.SPDF.model.provider.UnsupportedProviderException;
import stirling.software.SPDF.utils.UrlUtils;
@Slf4j
public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
private final ApplicationProperties applicationProperties;
public CustomOAuth2LogoutSuccessHandler(ApplicationProperties applicationProperties) {
this.applicationProperties = applicationProperties;
}
@Override
public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException {
String param = "logout=true";
String registrationId = null;
String issuer = null;
String clientId = null;
if (authentication == null) {
if (request.getParameter("userIsDisabled") != null) {
response.sendRedirect(
request.getContextPath() + "/login?erroroauth=userIsDisabled");
} else {
super.onLogoutSuccess(request, response, authentication);
}
return;
}
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (authentication instanceof OAuth2AuthenticationToken) {
OAuth2AuthenticationToken oauthToken = (OAuth2AuthenticationToken) authentication;
registrationId = oauthToken.getAuthorizedClientRegistrationId();
try {
Provider provider = oauth.getClient().get(registrationId);
issuer = provider.getIssuer();
clientId = provider.getClientId();
} catch (UnsupportedProviderException e) {
log.error(e.getMessage());
}
} else {
registrationId = oauth.getProvider() != null ? oauth.getProvider() : "";
issuer = oauth.getIssuer();
clientId = oauth.getClientId();
}
String errorMessage = "";
if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
param = "erroroauth=oauth2AuthenticationErrorWeb";
} else if ((errorMessage = request.getParameter("error")) != null) {
param = "error=" + sanitizeInput(errorMessage);
} else if ((errorMessage = request.getParameter("erroroauth")) != null) {
param = "erroroauth=" + sanitizeInput(errorMessage);
} else if (request.getParameter("oauth2AutoCreateDisabled") != null) {
param = "error=oauth2AutoCreateDisabled";
} else if (request.getParameter("oauth2_admin_blocked_user") != null) {
param = "erroroauth=oauth2_admin_blocked_user";
} else if (request.getParameter("userIsDisabled") != null) {
param = "erroroauth=userIsDisabled";
} else if (request.getParameter("badcredentials") != null) {
param = "error=badcredentials";
}
String redirect_url = UrlUtils.getOrigin(request) + "/login?" + param;
switch (registrationId.toLowerCase()) {
case "keycloak":
// Add Keycloak specific logout URL if needed
String logoutUrl =
issuer
+ "/protocol/openid-connect/logout"
+ "?client_id="
+ clientId
+ "&post_logout_redirect_uri="
+ response.encodeRedirectURL(redirect_url);
log.info("Redirecting to Keycloak logout URL: " + logoutUrl);
response.sendRedirect(logoutUrl);
break;
case "github":
// Add GitHub specific logout URL if needed
String githubLogoutUrl = "https://github.com/logout";
log.info("Redirecting to GitHub logout URL: " + githubLogoutUrl);
response.sendRedirect(githubLogoutUrl);
break;
case "google":
// Add Google specific logout URL if needed
// String googleLogoutUrl =
// "https://accounts.google.com/Logout?continue=https://appengine.google.com/_ah/logout?continue="
// + response.encodeRedirectURL(redirect_url);
log.info("Google does not have a specific logout URL");
// log.info("Redirecting to Google logout URL: " + googleLogoutUrl);
// response.sendRedirect(googleLogoutUrl);
// break;
default:
String defaultRedirectUrl = request.getContextPath() + "/login?" + param;
log.info("Redirecting to default logout URL: " + defaultRedirectUrl);
response.sendRedirect(defaultRedirectUrl);
break;
}
}
private String sanitizeInput(String input) {
return input.replaceAll("[^a-zA-Z0-9 ]", "");
}
}

View file

@ -1,68 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensaml.saml.saml2.core.Assertion;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
@Component
@Slf4j
public class ConvertResponseToAuthentication
implements Converter<ResponseToken, Saml2Authentication> {
private final Saml2AuthorityAttributeLookup saml2AuthorityAttributeLookup;
public ConvertResponseToAuthentication(
Saml2AuthorityAttributeLookup saml2AuthorityAttributeLookup) {
this.saml2AuthorityAttributeLookup = saml2AuthorityAttributeLookup;
}
@Override
public Saml2Authentication convert(ResponseToken responseToken) {
final Assertion assertion =
CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
final Map<String, List<Object>> attributes =
SamlAssertionUtils.getAssertionAttributes(assertion);
final String registrationId =
responseToken.getToken().getRelyingPartyRegistration().getRegistrationId();
final ScimSaml2AuthenticatedPrincipal principal =
new ScimSaml2AuthenticatedPrincipal(
assertion,
attributes,
saml2AuthorityAttributeLookup.getIdentityMappings(registrationId));
final Collection<? extends GrantedAuthority> assertionAuthorities =
getAssertionAuthorities(
attributes,
saml2AuthorityAttributeLookup.getAuthorityAttribute(registrationId));
return new Saml2Authentication(
principal, responseToken.getToken().getSaml2Response(), assertionAuthorities);
}
private static Collection<? extends GrantedAuthority> getAssertionAuthorities(
final Map<String, List<Object>> attributes, final String authoritiesAttributeName) {
if (attributes == null || attributes.isEmpty()) {
return Collections.emptySet();
}
final List<Object> groups = new ArrayList<>(attributes.get(authoritiesAttributeName));
return groups.stream()
.filter(String.class::isInstance)
.map(String.class::cast)
.map(String::toLowerCase)
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet());
}
}

View file

@ -1,51 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.IOException;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class CustomSAMLAuthenticationFailureHandler extends SimpleUrlAuthenticationFailureHandler {
@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception)
throws IOException, ServletException {
if (exception instanceof BadCredentialsException) {
log.error("BadCredentialsException", exception);
getRedirectStrategy().sendRedirect(request, response, "/login?error=badcredentials");
return;
}
if (exception instanceof DisabledException) {
log.error("User is deactivated: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
if (exception instanceof LockedException) {
log.error("Account locked: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?error=locked");
return;
}
if (exception instanceof Saml2AuthenticationException) {
log.error("SAML2 Authentication error: ", exception);
getRedirectStrategy()
.sendRedirect(request, response, "/logout?error=saml2AuthenticationError");
return;
}
log.error("Unhandled authentication exception", exception);
super.onAuthenticationFailure(request, response, exception);
}
}

View file

@ -1,108 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.IOException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.SavedRequest;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.LoginAttemptService;
import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.AuthenticationType;
import stirling.software.SPDF.utils.RequestUriUtils;
@Slf4j
public class CustomSAMLAuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private LoginAttemptService loginAttemptService;
private UserService userService;
private ApplicationProperties applicationProperties;
public CustomSAMLAuthenticationSuccessHandler(
LoginAttemptService loginAttemptService,
UserService userService,
ApplicationProperties applicationProperties) {
this.loginAttemptService = loginAttemptService;
this.userService = userService;
this.applicationProperties = applicationProperties;
}
@Override
public void onAuthenticationSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws ServletException, IOException {
Object principal = authentication.getPrincipal();
String username = "";
if (principal instanceof OAuth2User) {
OAuth2User oauthUser = (OAuth2User) principal;
username = oauthUser.getName();
} else if (principal instanceof UserDetails) {
UserDetails oauthUser = (UserDetails) principal;
username = oauthUser.getUsername();
} else if (principal instanceof ScimSaml2AuthenticatedPrincipal) {
ScimSaml2AuthenticatedPrincipal samlPrincipal =
(ScimSaml2AuthenticatedPrincipal) principal;
username = samlPrincipal.getName();
}
// Get the saved request
HttpSession session = request.getSession(false);
String contextPath = request.getContextPath();
SavedRequest savedRequest =
(session != null)
? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")
: null;
if (savedRequest != null
&& !RequestUriUtils.isStaticResource(contextPath, savedRequest.getRedirectUrl())) {
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
OAUTH2 oAuth = applicationProperties.getSecurity().getOauth2();
if (loginAttemptService.isBlocked(username)) {
if (session != null) {
session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST");
}
throw new LockedException(
"Your account has been locked due to too many failed login attempts.");
}
if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername(
username, AuthenticationType.OAUTH2)
&& oAuth.getAutoCreateUser()) {
response.sendRedirect(contextPath + "/logout?oauth2AuthenticationErrorWeb=true");
return;
}
try {
if (oAuth.getBlockRegistration()
&& !userService.usernameExistsIgnoreCase(username)) {
response.sendRedirect(contextPath + "/logout?oauth2_admin_blocked_user=true");
return;
}
if (principal instanceof OAuth2User) {
userService.processOAuth2PostLogin(username, oAuth.getAutoCreateUser());
}
response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException e) {
response.sendRedirect(contextPath + "/logout?invalidUsername=true");
return;
}
}
}
}

View file

@ -1,38 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.IOException;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class SAMLLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
@Override
public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException {
String redirectUrl = determineTargetUrl(request, response, authentication);
if (response.isCommitted()) {
log.debug("Response has already been committed. Unable to redirect to " + redirectUrl);
return;
}
getRedirectStrategy().sendRedirect(request, response, redirectUrl);
}
protected String determineTargetUrl(
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication) {
// Default to the root URL
return "/";
}
}

View file

@ -1,7 +0,0 @@
package stirling.software.SPDF.config.security.saml;
public interface Saml2AuthorityAttributeLookup {
String getAuthorityAttribute(String registrationId);
SimpleScimMappings getIdentityMappings(String registrationId);
}

View file

@ -1,17 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import org.springframework.stereotype.Component;
@Component
public class Saml2AuthorityAttributeLookupImpl implements Saml2AuthorityAttributeLookup {
@Override
public String getAuthorityAttribute(String registrationId) {
return "authorityAttributeName";
}
@Override
public SimpleScimMappings getIdentityMappings(String registrationId) {
return new SimpleScimMappings();
}
}

View file

@ -1,63 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.time.Instant;
import java.util.*;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.schema.*;
import org.opensaml.saml.saml2.core.Assertion;
public class SamlAssertionUtils {
public static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
assertion
.getAttributeStatements()
.forEach(
attributeStatement -> {
attributeStatement
.getAttributes()
.forEach(
attribute -> {
List<Object> attributeValues = new ArrayList<>();
attribute
.getAttributeValues()
.forEach(
xmlObject -> {
Object attributeValue =
getXmlObjectValue(
xmlObject);
if (attributeValue != null) {
attributeValues.add(
attributeValue);
}
});
attributeMap.put(
attribute.getName(), attributeValues);
});
});
return attributeMap;
}
public static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny) {
return ((XSAny) xmlObject).getTextContent();
} else if (xmlObject instanceof XSString) {
return ((XSString) xmlObject).getValue();
} else if (xmlObject instanceof XSInteger) {
return ((XSInteger) xmlObject).getValue();
} else if (xmlObject instanceof XSURI) {
return ((XSURI) xmlObject).getURI();
} else if (xmlObject instanceof XSBoolean) {
return ((XSBoolean) xmlObject).getValue().getValue();
} else if (xmlObject instanceof XSDateTime) {
Instant dateTime = ((XSDateTime) xmlObject).getValue();
return (dateTime != null) ? Instant.ofEpochMilli(dateTime.toEpochMilli()) : null;
}
return null;
}
}

View file

@ -1,42 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.security.cert.CertificateException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.model.ApplicationProperties;
@Configuration
@Slf4j
public class SamlConfig {
@Autowired ApplicationProperties applicationProperties;
@Bean
@ConditionalOnProperty(
value = "security.saml.enabled",
havingValue = "true",
matchIfMissing = false)
public RelyingPartyRegistrationRepository relyingPartyRegistrationRepository()
throws CertificateException {
RelyingPartyRegistration registration =
RelyingPartyRegistrations.fromMetadataLocation(
applicationProperties
.getSecurity()
.getSaml()
.getIdpMetadataLocation())
.entityId(applicationProperties.getSecurity().getSaml().getEntityId())
.registrationId(
applicationProperties.getSecurity().getSaml().getRegistrationId())
.build();
return new InMemoryRelyingPartyRegistrationRepository(registration);
}
}

View file

@ -1,89 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.opensaml.saml.saml2.core.Assertion;
import org.springframework.security.core.AuthenticatedPrincipal;
import org.springframework.util.Assert;
import com.unboundid.scim2.common.types.Email;
import com.unboundid.scim2.common.types.Name;
import com.unboundid.scim2.common.types.UserResource;
public class ScimSaml2AuthenticatedPrincipal implements AuthenticatedPrincipal, Serializable {
private static final long serialVersionUID = 1L;
private final transient UserResource userResource;
public ScimSaml2AuthenticatedPrincipal(
final Assertion assertion,
final Map<String, List<Object>> attributes,
final SimpleScimMappings attributeMappings) {
Assert.notNull(assertion, "assertion cannot be null");
Assert.notNull(assertion.getSubject(), "assertion subject cannot be null");
Assert.notNull(
assertion.getSubject().getNameID(), "assertion subject NameID cannot be null");
Assert.notNull(attributes, "attributes cannot be null");
Assert.notNull(attributeMappings, "attributeMappings cannot be null");
final Name name =
new Name()
.setFamilyName(
getAttribute(
attributes,
attributeMappings,
SimpleScimMappings::getFamilyName))
.setGivenName(
getAttribute(
attributes,
attributeMappings,
SimpleScimMappings::getGivenName));
final List<Email> emails = new ArrayList<>(1);
emails.add(
new Email()
.setValue(
getAttribute(
attributes,
attributeMappings,
SimpleScimMappings::getEmail))
.setPrimary(true));
userResource =
new UserResource()
.setUserName(assertion.getSubject().getNameID().getValue())
.setName(name)
.setEmails(emails);
}
private static String getAttribute(
final Map<String, List<Object>> attributes,
final SimpleScimMappings simpleScimMappings,
final Function<SimpleScimMappings, String> attributeMapper) {
final String key = attributeMapper.apply(simpleScimMappings);
final List<Object> values = attributes.getOrDefault(key, Collections.emptyList());
return values.stream()
.filter(String.class::isInstance)
.map(String.class::cast)
.findFirst()
.orElse(null);
}
@Override
public String getName() {
return this.userResource.getUserName();
}
public UserResource getUserResource() {
return this.userResource;
}
}

View file

@ -1,10 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import lombok.Data;
@Data
public class SimpleScimMappings {
String givenName;
String familyName;
String email;
}

View file

@ -0,0 +1,48 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import org.springframework.core.io.Resource;
import org.springframework.util.FileCopyUtils;
public class CertificateUtils {
public static X509Certificate readCertificate(Resource certificateResource) throws Exception {
String certificateString =
new String(
FileCopyUtils.copyToByteArray(certificateResource.getInputStream()),
StandardCharsets.UTF_8);
String certContent =
certificateString
.replace("-----BEGIN CERTIFICATE-----", "")
.replace("-----END CERTIFICATE-----", "")
.replaceAll("\\R", "")
.replaceAll("\\s+", "");
CertificateFactory cf = CertificateFactory.getInstance("X.509");
byte[] decodedCert = Base64.getDecoder().decode(certContent);
return (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(decodedCert));
}
public static RSAPrivateKey readPrivateKey(Resource privateKeyResource) throws Exception {
String privateKeyString =
new String(
FileCopyUtils.copyToByteArray(privateKeyResource.getInputStream()),
StandardCharsets.UTF_8);
String privateKeyContent =
privateKeyString
.replace("-----BEGIN PRIVATE KEY-----", "")
.replace("-----END PRIVATE KEY-----", "")
.replaceAll("\\R", "")
.replaceAll("\\s+", "");
KeyFactory kf = KeyFactory.getInstance("RSA");
byte[] decodedKey = Base64.getDecoder().decode(privateKeyContent);
return (RSAPrivateKey) kf.generatePrivate(new PKCS8EncodedKeySpec(decodedKey));
}
}

View file

@ -0,0 +1,45 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
public class CustomSaml2AuthenticatedPrincipal
implements Saml2AuthenticatedPrincipal, Serializable {
private final String name;
private final Map<String, List<Object>> attributes;
private final String nameId;
private final List<String> sessionIndexes;
public CustomSaml2AuthenticatedPrincipal(
String name,
Map<String, List<Object>> attributes,
String nameId,
List<String> sessionIndexes) {
this.name = name;
this.attributes = attributes;
this.nameId = nameId;
this.sessionIndexes = sessionIndexes;
}
@Override
public String getName() {
return this.name;
}
@Override
public Map<String, List<Object>> getAttributes() {
return this.attributes;
}
public String getNameId() {
return this.nameId;
}
public List<String> getSessionIndexes() {
return this.sessionIndexes;
}
}

View file

@ -0,0 +1,38 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.IOException;
import org.springframework.security.authentication.ProviderNotFoundException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class CustomSaml2AuthenticationFailureHandler extends SimpleUrlAuthenticationFailureHandler {
@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception)
throws IOException, ServletException {
if (exception instanceof Saml2AuthenticationException) {
Saml2Error error = ((Saml2AuthenticationException) exception).getSaml2Error();
getRedirectStrategy()
.sendRedirect(request, response, "/login?erroroauth=" + error.getErrorCode());
} else if (exception instanceof ProviderNotFoundException) {
getRedirectStrategy()
.sendRedirect(
request,
response,
"/login?erroroauth=not_authentication_provider_found");
}
log.error("AuthenticationException: " + exception);
}
}

View file

@ -0,0 +1,91 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.IOException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.SavedRequest;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.AllArgsConstructor;
import stirling.software.SPDF.config.security.LoginAttemptService;
import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.AuthenticationType;
import stirling.software.SPDF.utils.RequestUriUtils;
@AllArgsConstructor
public class CustomSaml2AuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private LoginAttemptService loginAttemptService;
private ApplicationProperties applicationProperties;
private UserService userService;
@Override
public void onAuthenticationSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws ServletException, IOException {
Object principal = authentication.getPrincipal();
if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
String username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
// Get the saved request
HttpSession session = request.getSession(false);
String contextPath = request.getContextPath();
SavedRequest savedRequest =
(session != null)
? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")
: null;
if (savedRequest != null
&& !RequestUriUtils.isStaticResource(
contextPath, savedRequest.getRedirectUrl())) {
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
SAML2 saml2 = applicationProperties.getSecurity().getSaml2();
if (loginAttemptService.isBlocked(username)) {
if (session != null) {
session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST");
}
throw new LockedException(
"Your account has been locked due to too many failed login attempts.");
}
if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername(
username, AuthenticationType.OAUTH2)
&& saml2.getAutoCreateUser()) {
response.sendRedirect(
contextPath + "/logout?oauth2AuthenticationErrorWeb=true");
return;
}
try {
if (saml2.getBlockRegistration()
&& !userService.usernameExistsIgnoreCase(username)) {
response.sendRedirect(
contextPath + "/login?erroroauth=oauth2_admin_blocked_user");
return;
}
userService.processOAuth2PostLogin(username, saml2.getAutoCreateUser());
response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException e) {
response.sendRedirect(contextPath + "/logout?invalidUsername=true");
return;
}
}
} else {
super.onAuthenticationSuccess(request, response, authentication);
}
}
}

View file

@ -0,0 +1,86 @@
package stirling.software.SPDF.config.security.saml2;
import java.util.*;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.stereotype.Component;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.model.User;
@Component
@Slf4j
public class CustomSaml2ResponseAuthenticationConverter
implements Converter<ResponseToken, Saml2Authentication> {
private UserService userService;
public CustomSaml2ResponseAuthenticationConverter(UserService userService) {
this.userService = userService;
}
@Override
public Saml2Authentication convert(ResponseToken responseToken) {
// Extract the assertion from the response
Assertion assertion = responseToken.getResponse().getAssertions().get(0);
// Extract the NameID
String nameId = assertion.getSubject().getNameID().getValue();
Optional<User> userOpt = userService.findByUsernameIgnoreCase(nameId);
SimpleGrantedAuthority simpleGrantedAuthority = new SimpleGrantedAuthority("ROLE_USER");
if (userOpt.isPresent()) {
User user = userOpt.get();
if (user != null) {
simpleGrantedAuthority =
new SimpleGrantedAuthority(userService.findRole(user).getAuthority());
}
}
// Extract the SessionIndexes
List<String> sessionIndexes = new ArrayList<>();
for (AuthnStatement authnStatement : assertion.getAuthnStatements()) {
sessionIndexes.add(authnStatement.getSessionIndex());
}
// Extract the Attributes
Map<String, List<Object>> attributes = extractAttributes(assertion);
// Create the custom principal
CustomSaml2AuthenticatedPrincipal principal =
new CustomSaml2AuthenticatedPrincipal(nameId, attributes, nameId, sessionIndexes);
// Create the Saml2Authentication
return new Saml2Authentication(
principal,
responseToken.getToken().getSaml2Response(),
Collections.singletonList(simpleGrantedAuthority));
}
private Map<String, List<Object>> extractAttributes(Assertion assertion) {
Map<String, List<Object>> attributes = new HashMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
String attributeName = attribute.getName();
List<Object> values = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) {
log.info("BOOL: " + ((XSBoolean) xmlObject).getValue());
values.add(((XSString) xmlObject).getValue());
}
attributes.put(attributeName, values);
}
}
return attributes;
}
}

View file

@ -16,6 +16,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import jakarta.transaction.Transactional; import jakarta.transaction.Transactional;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.model.SessionEntity; import stirling.software.SPDF.model.SessionEntity;
@Component @Component
@ -50,6 +51,8 @@ public class SessionPersistentRegistry implements SessionRegistry {
principalName = ((UserDetails) principal).getUsername(); principalName = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) { } else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName(); principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) { } else if (principal instanceof String) {
principalName = (String) principal; principalName = (String) principal;
} }
@ -79,6 +82,8 @@ public class SessionPersistentRegistry implements SessionRegistry {
principalName = ((UserDetails) principal).getUsername(); principalName = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) { } else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName(); principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) { } else if (principal instanceof String) {
principalName = (String) principal; principalName = (String) principal;
} }

View file

@ -32,6 +32,7 @@ import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.UserService; import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry; import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.AuthenticationType; import stirling.software.SPDF.model.AuthenticationType;
import stirling.software.SPDF.model.Role; import stirling.software.SPDF.model.Role;
@ -336,6 +337,8 @@ public class UserController {
userNameP = ((UserDetails) principal).getUsername(); userNameP = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) { } else if (principal instanceof OAuth2User) {
userNameP = ((OAuth2User) principal).getName(); userNameP = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
userNameP = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) { } else if (principal instanceof String) {
userNameP = (String) principal; userNameP = (String) principal;
} }

View file

@ -21,10 +21,13 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry; import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.*; import stirling.software.SPDF.model.*;
import stirling.software.SPDF.model.ApplicationProperties.Security;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2; import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2.Client; import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2.Client;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.provider.GithubProvider; import stirling.software.SPDF.model.provider.GithubProvider;
import stirling.software.SPDF.model.provider.GoogleProvider; import stirling.software.SPDF.model.provider.GoogleProvider;
import stirling.software.SPDF.model.provider.KeycloakProvider; import stirling.software.SPDF.model.provider.KeycloakProvider;
@ -51,38 +54,54 @@ public class AccountWebController {
Map<String, String> providerList = new HashMap<>(); Map<String, String> providerList = new HashMap<>();
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2(); Security securityProps = applicationProperties.getSecurity();
OAUTH2 oauth = securityProps.getOauth2();
if (oauth != null) { if (oauth != null) {
if (oauth.getEnabled()) {
if (oauth.isSettingsValid()) { if (oauth.isSettingsValid()) {
providerList.put("oidc", oauth.getProvider()); providerList.put("/oauth2/authorization/oidc", oauth.getProvider());
} }
Client client = oauth.getClient(); Client client = oauth.getClient();
if (client != null) { if (client != null) {
GoogleProvider google = client.getGoogle(); GoogleProvider google = client.getGoogle();
if (google.isSettingsValid()) { if (google.isSettingsValid()) {
providerList.put(google.getName(), google.getClientName()); providerList.put(
"/oauth2/authorization/" + google.getName(),
google.getClientName());
} }
GithubProvider github = client.getGithub(); GithubProvider github = client.getGithub();
if (github.isSettingsValid()) { if (github.isSettingsValid()) {
providerList.put(github.getName(), github.getClientName()); providerList.put(
"/oauth2/authorization/" + github.getName(),
github.getClientName());
} }
KeycloakProvider keycloak = client.getKeycloak(); KeycloakProvider keycloak = client.getKeycloak();
if (keycloak.isSettingsValid()) { if (keycloak.isSettingsValid()) {
providerList.put(keycloak.getName(), keycloak.getClientName()); providerList.put(
"/oauth2/authorization/" + keycloak.getName(),
keycloak.getClientName());
} }
} }
} }
}
SAML2 saml2 = securityProps.getSaml2();
if (saml2 != null) {
if (saml2.getEnabled()) {
providerList.put("/saml2/authenticate/" + saml2.getRegistrationId(), "SAML 2");
}
}
// Remove any null keys/values from the providerList // Remove any null keys/values from the providerList
providerList providerList
.entrySet() .entrySet()
.removeIf(entry -> entry.getKey() == null || entry.getValue() == null); .removeIf(entry -> entry.getKey() == null || entry.getValue() == null);
model.addAttribute("providerlist", providerList); model.addAttribute("providerlist", providerList);
model.addAttribute("loginMethod", applicationProperties.getSecurity().getLoginMethod()); model.addAttribute("loginMethod", securityProps.getLoginMethod());
model.addAttribute( model.addAttribute("altLogin", securityProps.isAltLogin());
"oAuth2Enabled", applicationProperties.getSecurity().getOauth2().getEnabled());
model.addAttribute("currentPage", "login"); model.addAttribute("currentPage", "login");
@ -349,6 +368,17 @@ public class AccountWebController {
// Add oAuth2 Login attributes to the model // Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", true); model.addAttribute("oAuth2Login", true);
} }
if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
// Cast the principal object to OAuth2User
CustomSaml2AuthenticatedPrincipal userDetails =
(CustomSaml2AuthenticatedPrincipal) principal;
// Retrieve username and other attributes
username = userDetails.getName();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", true);
}
if (username != null) { if (username != null) {
// Fetch user details from the database // Fetch user details from the database
Optional<User> user = Optional<User> user =

View file

@ -1,13 +1,17 @@
package stirling.software.SPDF.model; package stirling.software.SPDF.model;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.PropertySource; import org.springframework.context.annotation.PropertySource;
@ -18,6 +22,8 @@ import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import lombok.Data; import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString; import lombok.ToString;
import stirling.software.SPDF.config.YamlPropertySourceFactory; import stirling.software.SPDF.config.YamlPropertySourceFactory;
import stirling.software.SPDF.model.provider.GithubProvider; import stirling.software.SPDF.model.provider.GithubProvider;
@ -41,7 +47,6 @@ public class ApplicationProperties {
private AutomaticallyGenerated automaticallyGenerated = new AutomaticallyGenerated(); private AutomaticallyGenerated automaticallyGenerated = new AutomaticallyGenerated();
private EnterpriseEdition enterpriseEdition = new EnterpriseEdition(); private EnterpriseEdition enterpriseEdition = new EnterpriseEdition();
private AutoPipeline autoPipeline = new AutoPipeline(); private AutoPipeline autoPipeline = new AutoPipeline();
private static final Logger logger = LoggerFactory.getLogger(ApplicationProperties.class);
@Data @Data
public static class AutoPipeline { public static class AutoPipeline {
@ -63,41 +68,108 @@ public class ApplicationProperties {
private Boolean csrfDisabled; private Boolean csrfDisabled;
private InitialLogin initialLogin = new InitialLogin(); private InitialLogin initialLogin = new InitialLogin();
private OAUTH2 oauth2 = new OAUTH2(); private OAUTH2 oauth2 = new OAUTH2();
private SAML saml = new SAML(); private SAML2 saml2 = new SAML2();
private int loginAttemptCount; private int loginAttemptCount;
private long loginResetTimeMinutes; private long loginResetTimeMinutes;
private String loginMethod = "all"; private String loginMethod = "all";
public Boolean isAltLogin() {
return saml2.getEnabled() || oauth2.getEnabled();
}
public enum LoginMethods {
ALL("all"),
NORMAL("normal"),
OAUTH2("oauth2"),
SAML2("saml2");
private String method;
LoginMethods(String method) {
this.method = method;
}
@Override
public String toString() {
return method;
}
}
public boolean isUserPass() {
return (loginMethod.equalsIgnoreCase(LoginMethods.NORMAL.toString())
|| loginMethod.equalsIgnoreCase(LoginMethods.ALL.toString()));
}
public boolean isOauth2Activ() {
return (oauth2 != null
&& oauth2.getEnabled()
&& !loginMethod.equalsIgnoreCase(LoginMethods.NORMAL.toString()));
}
public boolean isSaml2Activ() {
return (saml2 != null
&& saml2.getEnabled()
&& !loginMethod.equalsIgnoreCase(LoginMethods.NORMAL.toString()));
}
@Data @Data
public static class InitialLogin { public static class InitialLogin {
private String username; private String username;
@ToString.Exclude private String password; @ToString.Exclude private String password;
} }
@Data @Getter
public static class SAML { @Setter
public static class SAML2 {
private Boolean enabled = false; private Boolean enabled = false;
private String entityId; private Boolean autoCreateUser = false;
private String registrationId; private Boolean blockRegistration = false;
private String spBaseUrl; private String registrationId = "stirling";
private String idpMetadataLocation; private String idpMetadataUri;
private KeyStore keystore; private String idpSingleLogoutUrl;
private String idpSingleLoginUrl;
private String idpIssuer;
private String idpCert;
private String privateKey;
private String spCert;
@Data public InputStream getIdpMetadataUri() throws IOException {
public static class KeyStore { if (idpMetadataUri.startsWith("classpath:")) {
private String keystoreLocation; return new ClassPathResource(idpMetadataUri.substring("classpath".length()))
private String keystorePassword; .getInputStream();
private String keyAlias;
private String keyPassword;
private String realmCertificateAlias;
public Resource getKeystoreResource() {
if (keystoreLocation.startsWith("classpath:")) {
return new ClassPathResource(
keystoreLocation.substring("classpath:".length()));
} else {
return new FileSystemResource(keystoreLocation);
} }
try {
URI uri = new URI(idpMetadataUri);
URL url = uri.toURL();
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("GET");
return connection.getInputStream();
} catch (URISyntaxException e) {
throw new IOException("Invalid URI format: " + idpMetadataUri, e);
}
}
public Resource getSpCert() {
if (spCert.startsWith("classpath:")) {
return new ClassPathResource(spCert.substring("classpath:".length()));
} else {
return new FileSystemResource(spCert);
}
}
public Resource getidpCert() {
if (idpCert.startsWith("classpath:")) {
return new ClassPathResource(idpCert.substring("classpath:".length()));
} else {
return new FileSystemResource(idpCert);
}
}
public Resource getPrivateKey() {
if (privateKey.startsWith("classpath:")) {
return new ClassPathResource(privateKey.substring("classpath:".length()));
} else {
return new FileSystemResource(privateKey);
} }
} }
} }

View file

@ -19,7 +19,6 @@ public class Provider implements ProviderInterface {
return true; return true;
} }
return false; return false;
// throw new IllegalArgumentException(getName() + ": " + name + " is required!");
} }
protected boolean isValid(Collection<String> value, String name) { protected boolean isValid(Collection<String> value, String name) {
@ -27,66 +26,55 @@ public class Provider implements ProviderInterface {
return true; return true;
} }
return false; return false;
// throw new IllegalArgumentException(getName() + ": " + name + " is required!");
} }
@Override @Override
public Collection<String> getScopes() { public Collection<String> getScopes() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getScope'"); throw new UnsupportedOperationException("Unimplemented method 'getScope'");
} }
@Override @Override
public void setScopes(String scopes) { public void setScopes(String scopes) {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'setScope'"); throw new UnsupportedOperationException("Unimplemented method 'setScope'");
} }
@Override @Override
public String getUseAsUsername() { public String getUseAsUsername() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getUseAsUsername'"); throw new UnsupportedOperationException("Unimplemented method 'getUseAsUsername'");
} }
@Override @Override
public void setUseAsUsername(String useAsUsername) { public void setUseAsUsername(String useAsUsername) {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'setUseAsUsername'"); throw new UnsupportedOperationException("Unimplemented method 'setUseAsUsername'");
} }
@Override @Override
public String getIssuer() { public String getIssuer() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getIssuer'"); throw new UnsupportedOperationException("Unimplemented method 'getIssuer'");
} }
@Override @Override
public void setIssuer(String issuer) { public void setIssuer(String issuer) {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'setIssuer'"); throw new UnsupportedOperationException("Unimplemented method 'setIssuer'");
} }
@Override @Override
public String getClientSecret() { public String getClientSecret() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getClientSecret'"); throw new UnsupportedOperationException("Unimplemented method 'getClientSecret'");
} }
@Override @Override
public void setClientSecret(String clientSecret) { public void setClientSecret(String clientSecret) {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'setClientSecret'"); throw new UnsupportedOperationException("Unimplemented method 'setClientSecret'");
} }
@Override @Override
public String getClientId() { public String getClientId() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getClientId'"); throw new UnsupportedOperationException("Unimplemented method 'getClientId'");
} }
@Override @Override
public void setClientId(String clientId) { public void setClientId(String clientId) {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'setClientId'"); throw new UnsupportedOperationException("Unimplemented method 'setClientId'");
} }
} }

View file

@ -47,6 +47,18 @@ security:
useAsUsername: email # Default is 'email'; custom fields can be used as the username useAsUsername: email # Default is 'email'; custom fields can be used as the username
scopes: openid, profile, email # Specify the scopes for which the application will request permissions scopes: openid, profile, email # Specify the scopes for which the application will request permissions
provider: google # Set this to your OAuth provider's name, e.g., 'google' or 'keycloak' provider: google # Set this to your OAuth provider's name, e.g., 'google' or 'keycloak'
saml2:
enabled: false
autoCreateUser: false # set to 'true' to allow auto-creation of non-existing users
blockRegistration: false # set to 'true' to deny login with SSO without prior registration by an admin
registrationId: stirling
idpMetadataUri: https://dev-XXXXXXXX.okta.com/app/externalKey/sso/saml/metadata
idpSingleLogoutUrl: https://dev-XXXXXXXX.okta.com/app/dev-XXXXXXXX_stirlingpdf_1/externalKey/slo/saml
idpSingleLoginUrl: https://dev-XXXXXXXX.okta.com/app/dev-XXXXXXXX_stirlingpdf_1/externalKey/sso/saml
idpIssuer: http://www.okta.com/externalKey
idpCert: classpath:octa.crt
privateKey: classpath:saml-private-key.key
spCert: classpath:saml-public-cert.crt
# Enterprise edition settings unused for now please ignore! # Enterprise edition settings unused for now please ignore!
enterpriseEdition: enterpriseEdition:

View file

@ -283,7 +283,7 @@
</script> </script>
<th:block th:insert="~{fragments/footer.html :: footer}"></th:block> <th:block th:insert="~{fragments/footer.html :: footer}"></th:block>
</div> </div>
<div th:if="${oAuth2Enabled}" class="modal fade" id="editUserModal" tabindex="-1" role="dialog" aria-labelledby="editUserModalLabel" aria-hidden="true"> <div th:if="${altLogin}" class="modal fade" id="editUserModal" tabindex="-1" role="dialog" aria-labelledby="editUserModalLabel" aria-hidden="true">
<div class="modal-dialog modal-dialog-centered" role="document"> <div class="modal-dialog modal-dialog-centered" role="document">
<div class="modal-content"> <div class="modal-content">
<div class="modal-header"> <div class="modal-header">

View file

@ -114,7 +114,7 @@
<img class="my-4" th:src="@{'/favicon.svg'}" alt="favicon" width="144" height="144"> <img class="my-4" th:src="@{'/favicon.svg'}" alt="favicon" width="144" height="144">
<h1 class="h1 mb-3 fw-normal" th:text="${@appName}">Stirling-PDF</h1> <h1 class="h1 mb-3 fw-normal" th:text="${@appName}">Stirling-PDF</h1>
<div th:if="${oAuth2Enabled} and (${loginMethod} == 'all' or ${loginMethod} == 'oauth2')"> <div th:if="${altLogin} and (${loginMethod} == 'all' or ${loginMethod} == 'oauth2')">
<a href="#" class="w-100 btn btn-lg btn-primary" data-bs-toggle="modal" data-bs-target="#loginsModal" th:text="#{login.ssoSignIn}">Login Via SSO</a> <a href="#" class="w-100 btn btn-lg btn-primary" data-bs-toggle="modal" data-bs-target="#loginsModal" th:text="#{login.ssoSignIn}">Login Via SSO</a>
<br> <br>
<br> <br>
@ -168,7 +168,7 @@
</main> </main>
<th:block th:insert="~{fragments/footer.html :: footer}"></th:block> <th:block th:insert="~{fragments/footer.html :: footer}"></th:block>
</div> </div>
<div th:if="${oAuth2Enabled}" class="modal fade" id="loginsModal" tabindex="-1" role="dialog" aria-labelledby="loginsModalLabel" aria-hidden="true"> <div th:if="${altLogin}" class="modal fade" id="loginsModal" tabindex="-1" role="dialog" aria-labelledby="loginsModalLabel" aria-hidden="true">
<div class="modal-dialog modal-dialog-centered" role="document"> <div class="modal-dialog modal-dialog-centered" role="document">
<div class="modal-content"> <div class="modal-content">
<div class="modal-header"> <div class="modal-header">
@ -181,7 +181,7 @@
</div> </div>
<div class="modal-body"> <div class="modal-body">
<div class="mb-3" th:each="provider : ${providerlist}"> <div class="mb-3" th:each="provider : ${providerlist}">
<a th:href="@{|/oauth2/authorization/${provider.key}|}" th:text="${provider.value}" class="w-100 btn btn-lg btn-primary">OpenID Connect</a> <a th:href="@{|${provider.key}|}" th:text="${provider.value}" class="w-100 btn btn-lg btn-primary">Login Provider</a>
</div> </div>
</div> </div>
<div class="modal-footer"> <div class="modal-footer">