APi key stuff

This commit is contained in:
Anthony Stirling 2023-08-13 10:53:00 +01:00
parent ab9a22d8e7
commit 7f7ea6da9f
6 changed files with 163 additions and 37 deletions

View file

@ -16,6 +16,7 @@ import org.springframework.web.filter.OncePerRequestFilter;
import io.github.bucket4j.Bandwidth; import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket; import io.github.bucket4j.Bucket;
import io.github.bucket4j.Bucket4j; import io.github.bucket4j.Bucket4j;
import io.github.bucket4j.ConsumptionProbe;
import io.github.bucket4j.Refill; import io.github.bucket4j.Refill;
import jakarta.servlet.FilterChain; import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
@ -26,18 +27,22 @@ public class UserBasedRateLimitingFilter extends OncePerRequestFilter {
private final Map<String, Bucket> buckets = new ConcurrentHashMap<>(); private final Map<String, Bucket> buckets = new ConcurrentHashMap<>();
@Autowired
private UserDetailsService userDetailsService;
@Override @Override
protected void doFilterInternal(HttpServletRequest request, protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException { FilterChain filterChain) throws ServletException, IOException {
String method = request.getMethod(); String method = request.getMethod();
if (!"POST".equalsIgnoreCase(method)) { if (!"POST".equalsIgnoreCase(method)) {
// If the request is not a POST, just pass it through without rate limiting
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
String identifier; String identifier;
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
@ -49,20 +54,25 @@ public class UserBasedRateLimitingFilter extends OncePerRequestFilter {
} }
Bucket userBucket = buckets.computeIfAbsent(identifier, k -> createUserBucket()); Bucket userBucket = buckets.computeIfAbsent(identifier, k -> createUserBucket());
ConsumptionProbe probe = userBucket.tryConsumeAndReturnRemaining(1);
if (userBucket.tryConsume(1)) { if (probe.isConsumed()) {
response.setHeader("X-Rate-Limit-Remaining", Long.toString(probe.getRemainingTokens()));
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} else { } else {
long waitForRefill = probe.getNanosToWaitForRefill() / 1_000_000_000;
response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value()); response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
response.getWriter().write("Rate limit exceeded."); response.setHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
response.getWriter().write("Rate limit exceeded for POST requests.");
return; return;
} }
} }
//https://www.baeldung.com/spring-bucket4j
private Bucket createUserBucket() { private Bucket createUserBucket() {
Refill refill = Refill.of(3, Duration.ofDays(1)); Bandwidth limit = Bandwidth.classic(1000, Refill.intervally(1000, Duration.ofDays(1)));
Bandwidth limit = Bandwidth.classic(3, refill).withInitialTokens(3); return Bucket.builder().addLimit(limit).build();
return Bucket4j.builder().addLimit(limit).build();
}
} }
}

View file

@ -18,9 +18,9 @@ public class InitialSetup {
String initialUsername = System.getenv("INITIAL_USERNAME"); String initialUsername = System.getenv("INITIAL_USERNAME");
String initialPassword = System.getenv("INITIAL_PASSWORD"); String initialPassword = System.getenv("INITIAL_PASSWORD");
if(initialUsername != null && initialPassword != null) { if(initialUsername != null && initialPassword != null) {
userService.saveUser(initialUsername, initialPassword, Role.ADMIN); userService.saveUser(initialUsername, initialPassword, Role.ADMIN.getRoleId());
} else { } else {
userService.saveUser("admin", "password", Role.ADMIN); userService.saveUser("admin", "password", Role.ADMIN.getRoleId());
} }
} }
} }

View file

@ -1,10 +1,20 @@
package stirling.software.SPDF.config.security; package stirling.software.SPDF.config.security;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -21,6 +31,68 @@ public class UserService {
@Autowired @Autowired
private PasswordEncoder passwordEncoder; private PasswordEncoder passwordEncoder;
public Authentication getAuthentication(String apiKey) {
User user = getUserByApiKey(apiKey);
if (user == null) {
throw new UsernameNotFoundException("API key is not valid");
}
// Convert the user into an Authentication object
return new UsernamePasswordAuthenticationToken(
user, // principal (typically the user)
null, // credentials (we don't expose the password or API key here)
getAuthorities(user) // user's authorities (roles/permissions)
);
}
private Collection<? extends GrantedAuthority> getAuthorities(User user) {
// Convert each Authority object into a SimpleGrantedAuthority object.
return user.getAuthorities().stream()
.map((Authority authority) -> new SimpleGrantedAuthority(authority.getAuthority()))
.collect(Collectors.toList());
}
private String generateApiKey() {
String apiKey;
do {
apiKey = UUID.randomUUID().toString();
} while (userRepository.findByApiKey(apiKey) != null); // Ensure uniqueness
return apiKey;
}
public User addApiKeyToUser(String username) {
User user = userRepository.findByUsername(username)
.orElseThrow(() -> new UsernameNotFoundException("User not found"));
user.setApiKey(generateApiKey());
return userRepository.save(user);
}
public User refreshApiKeyForUser(String username) {
return addApiKeyToUser(username); // reuse the add API key method for refreshing
}
public String getApiKeyForUser(String username) {
User user = userRepository.findByUsername(username)
.orElseThrow(() -> new UsernameNotFoundException("User not found"));
return user.getApiKey();
}
public boolean isValidApiKey(String apiKey) {
return userRepository.findByApiKey(apiKey) != null;
}
public User getUserByApiKey(String apiKey) {
return userRepository.findByApiKey(apiKey);
}
public boolean validateApiKeyForUser(String username, String apiKey) {
Optional<User> userOpt = userRepository.findByUsername(username);
return userOpt.isPresent() && userOpt.get().getApiKey().equals(apiKey);
}
public void saveUser(String username, String password) { public void saveUser(String username, String password) {
User user = new User(); User user = new User();
user.setUsername(username); user.setUsername(username);

View file

@ -1,10 +1,42 @@
package stirling.software.SPDF.model; package stirling.software.SPDF.model;
public final class Role { public enum Role {
public static final String ADMIN = "ROLE_ADMIN"; // Unlimited access
public static final String USER = "ROLE_USER"; ADMIN("ROLE_ADMIN", Integer.MAX_VALUE, Integer.MAX_VALUE),
public static final String LIMITED_API_USER = "ROLE_LIMITED_API_USER";
public static final String WEB_ONLY_USER = "ROLE_WEB_ONLY_USER"; // Unlimited access
USER("ROLE_USER", Integer.MAX_VALUE, Integer.MAX_VALUE),
// 40 API calls Per Day, 40 web calls
LIMITED_API_USER("ROLE_LIMITED_API_USER", 40, 40),
// 20 API calls Per Day, 20 web calls
EXTRA_LIMITED_API_USER("ROLE_EXTRA_LIMITED_API_USER", 20, 20),
// 0 API calls per day and 20 web calls
WEB_ONLY_USER("ROLE_WEB_ONLY_USER", 0, 20);
private final String roleId;
private final int apiCallsPerDay;
private final int webCallsPerDay;
Role(String roleId, int apiCallsPerDay, int webCallsPerDay) {
this.roleId = roleId;
this.apiCallsPerDay = apiCallsPerDay;
this.webCallsPerDay = webCallsPerDay;
}
public String getRoleId() {
return roleId;
}
public int getApiCallsPerDay() {
return apiCallsPerDay;
}
public int getWebCallsPerDay() {
return webCallsPerDay;
}

View file

@ -28,6 +28,9 @@ public class User {
@Column(name = "password") @Column(name = "password")
private String password; private String password;
@Column(name = "apiKey")
private String apiKey;
@Column(name = "enabled") @Column(name = "enabled")
private boolean enabled; private boolean enabled;
@ -42,6 +45,14 @@ public class User {
public String getApiKey() {
return apiKey;
}
public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}
public Map<String, String> getSettings() { public Map<String, String> getSettings() {
return settings; return settings;
} }

View file

@ -8,5 +8,6 @@ import stirling.software.SPDF.model.User;
public interface UserRepository extends JpaRepository<User, String> { public interface UserRepository extends JpaRepository<User, String> {
Optional<User> findByUsername(String username); Optional<User> findByUsername(String username);
User findByApiKey(String apiKey);
} }