1use camino::Utf8Path;
53use serde::Deserialize;
54
55use crate::http::SourceAllowlist;
56
/// A host pattern that has passed [`validate_pattern`].
///
/// The wrapped string is private and every constructor visible in this file
/// ([`HostPattern::new`], the two `TryFrom` impls and the `Deserialize` impl)
/// runs the validator first, so a value of this type is either a literal
/// dotted host name (`example.org`) or a single leading-wildcard pattern
/// (`*.example.org`).
///
/// The text is stored exactly as supplied; no case folding or other
/// normalisation is applied here.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HostPattern(String);
68
69impl HostPattern {
70 pub fn new(raw: impl Into<String>) -> Result<Self, PatternError> {
78 let s: String = raw.into();
79 validate_pattern(&s)?;
80 Ok(Self(s))
81 }
82
83 pub fn as_str(&self) -> &str {
85 &self.0
86 }
87}
88
89impl TryFrom<&str> for HostPattern {
90 type Error = PatternError;
91 fn try_from(value: &str) -> Result<Self, Self::Error> {
92 Self::new(value)
93 }
94}
95
96impl TryFrom<String> for HostPattern {
97 type Error = PatternError;
98 fn try_from(value: String) -> Result<Self, Self::Error> {
99 Self::new(value)
100 }
101}
102
103impl<'de> serde::Deserialize<'de> for HostPattern {
104 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
105 let raw = String::deserialize(d)?;
106 Self::new(raw).map_err(serde::de::Error::custom)
107 }
108}
109
/// One user-supplied allowlist entry: a validated host pattern plus an
/// optional free-text note.
///
/// This is the element type returned by [`load`] and consumed by
/// [`merge_into_allowlists`].
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[non_exhaustive]
pub struct UserExtensionHost {
    /// The validated pattern (literal host or `*.<suffix>`).
    pub host: HostPattern,
    /// Optional annotation carried alongside the pattern; `None` when the
    /// key is absent. Nothing in this file interprets it.
    #[serde(default)]
    pub note: Option<String>,
}
127
128impl UserExtensionHost {
129 #[cfg(test)]
131 #[allow(clippy::expect_used)]
132 pub(crate) fn for_test(host: &str) -> Self {
133 Self {
134 host: HostPattern::new(host).expect("test host must be valid"),
135 note: None,
136 }
137 }
138}
139
/// Reason a host pattern was rejected by [`validate_pattern`].
///
/// Validation stops at the first problem it finds, so one pattern yields at
/// most one of these.
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum PatternError {
    /// The pattern is the empty string.
    #[error("empty pattern")]
    Empty,
    /// Trimming the pattern changes it, i.e. it starts or ends with
    /// whitespace.
    #[error("pattern has leading or trailing whitespace")]
    Whitespace,
    /// The pattern is exactly `*`.
    #[error("bare wildcard `*` is not allowed")]
    BareWildcard,
    /// The pattern contains `*` but does not start with `*.`
    /// (for example `foo.*.org` or `*foo.bar`).
    #[error("wildcard `*` is only allowed as the first character followed by `.`")]
    MisplacedWildcard,
    /// The pattern starts with `*.` and contains a further `*` after that
    /// prefix (for example `*.edu.*`).
    #[error("multi-segment globs are not allowed; use a single `*.<suffix>`")]
    MultiSegmentGlob,
    /// The pattern is exactly `*.`, leaving no suffix to match against.
    #[error("nothing after wildcard prefix `*.`")]
    EmptySuffix,
    /// The host part (after any `*.` prefix has been removed) contains no
    /// `.`, so it is a single label.
    #[error("host must contain at least one `.`")]
    NoDot,
    /// Splitting the host part on `.` produced an empty label.
    #[error("empty label (consecutive `.` or leading/trailing `.`)")]
    EmptyLabel,
    /// A label begins or ends with `-`.
    #[error("label `{label}` starts or ends with `-`")]
    LabelHyphenBorder {
        /// The offending label, copied from the pattern.
        label: String,
    },
    /// A label contains a character other than an ASCII letter, ASCII digit
    /// or `-`.
    #[error("label `{label}` contains a non-host character (allowed: A-Z a-z 0-9 - .)")]
    BadChar {
        /// The offending label, copied from the pattern.
        label: String,
    },
}
185
/// Errors produced while loading user host extensions from a config file.
///
/// Every variant carries the config path rendered as a string so the message
/// can name the file involved.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum UserExtensionError {
    /// Reading the file failed. [`load`] does not report a missing file this
    /// way; it treats `NotFound` as "no extensions".
    #[error("io reading {path}: {source}")]
    Io {
        /// Path of the config file that could not be read.
        path: String,
        /// The underlying I/O error.
        #[source]
        source: std::io::Error,
    },
    /// The file contents could not be deserialized as TOML into the expected
    /// structure.
    #[error("toml parse of {path}: {source}")]
    Parse {
        /// Path of the config file that failed to parse.
        path: String,
        /// The underlying TOML deserialization error.
        #[source]
        source: toml::de::Error,
    },
    /// The TOML was well-formed but one or more `host` values failed pattern
    /// validation. All failures are reported together.
    #[error("invalid host pattern(s) in {path}: {issues:?}")]
    InvalidPatterns {
        /// Path of the config file containing the bad patterns.
        path: String,
        /// One entry per rejected pattern, in file order.
        issues: Vec<InvalidPatternIssue>,
    },
}
221
/// A single rejected host pattern together with the reason it was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct InvalidPatternIssue {
    /// The pattern text exactly as it appeared in the config file.
    pub pattern: String,
    /// Why validation rejected it.
    pub kind: PatternError,
}
231
232impl std::fmt::Display for InvalidPatternIssue {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 write!(f, "`{}`: {}", self.pattern, self.kind)
235 }
236}
237
238pub fn load(config_path: &Utf8Path) -> Result<Vec<UserExtensionHost>, UserExtensionError> {
256 let text = match std::fs::read_to_string(config_path.as_std_path()) {
257 Ok(s) => s,
258 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
259 Err(e) => {
260 return Err(UserExtensionError::Io {
261 path: config_path.to_string(),
262 source: e,
263 })
264 }
265 };
266 parse_str(&text, config_path)
267}
268
/// Top-level shape of the config file as far as this module is concerned.
///
/// Only the optional `network` table is read. The tests in this file show a
/// config containing other tables (e.g. `[store]`) parsing successfully; the
/// flattened `IgnoredAny` field is what receives those other keys.
#[derive(Debug, Default, Deserialize)]
struct RawConfig {
    /// The `[network]` table, if present.
    #[serde(default)]
    network: Option<RawNetwork>,
    /// Sink for every other top-level key; the value is never read.
    #[serde(flatten)]
    _other: serde::de::IgnoredAny,
}
281
/// The `[network]` table as read by this module.
///
/// Only `additional_hosts` is read. The tests in this file show other keys in
/// the table (e.g. `contact_email`) being accepted; the flattened
/// `IgnoredAny` field is what receives them.
#[derive(Debug, Default, Deserialize)]
struct RawNetwork {
    /// Entries of the `[[network.additional_hosts]]` array; empty when the
    /// key is absent.
    #[serde(default)]
    additional_hosts: Vec<RawHost>,
    /// Sink for every other key in the table; the value is never read.
    #[serde(flatten)]
    _other: serde::de::IgnoredAny,
}
289
/// One `[[network.additional_hosts]]` entry before pattern validation.
///
/// `host` is kept as a plain `String` so that `parse_str` can validate every
/// entry itself and report all bad patterns together. Unknown keys are
/// rejected (`deny_unknown_fields`), so a misspelled key such as `notez`
/// becomes a parse error.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct RawHost {
    /// Unvalidated pattern text.
    host: String,
    /// Optional free-text note; `None` when absent.
    #[serde(default)]
    note: Option<String>,
}
300
301fn parse_str(
304 text: &str,
305 config_path: &Utf8Path,
306) -> Result<Vec<UserExtensionHost>, UserExtensionError> {
307 let raw: RawConfig = toml::from_str(text).map_err(|e| UserExtensionError::Parse {
308 path: config_path.to_string(),
309 source: e,
310 })?;
311 let raw_hosts = raw.network.unwrap_or_default().additional_hosts;
312
313 let mut issues = Vec::new();
317 let mut validated = Vec::with_capacity(raw_hosts.len());
318 for raw_host in raw_hosts {
319 match HostPattern::new(raw_host.host.clone()) {
320 Ok(host) => validated.push(UserExtensionHost {
321 host,
322 note: raw_host.note,
323 }),
324 Err(kind) => issues.push(InvalidPatternIssue {
325 pattern: raw_host.host,
326 kind,
327 }),
328 }
329 }
330 if !issues.is_empty() {
331 return Err(UserExtensionError::InvalidPatterns {
332 path: config_path.to_string(),
333 issues,
334 });
335 }
336 Ok(validated)
337}
338
339pub fn validate_pattern(pattern: &str) -> Result<(), PatternError> {
354 if pattern.is_empty() {
355 return Err(PatternError::Empty);
356 }
357 if pattern.trim() != pattern {
358 return Err(PatternError::Whitespace);
359 }
360 if pattern == "*" {
361 return Err(PatternError::BareWildcard);
362 }
363 let body = match pattern.strip_prefix("*.") {
364 Some(rest) => {
365 if rest.contains('*') {
368 return Err(PatternError::MultiSegmentGlob);
369 }
370 rest
371 }
372 None if pattern.contains('*') => {
373 return Err(PatternError::MisplacedWildcard);
376 }
377 None => pattern,
378 };
379 if body.is_empty() {
380 return Err(PatternError::EmptySuffix);
381 }
382 validate_fqdn(body)
383}
384
385fn validate_fqdn(body: &str) -> Result<(), PatternError> {
386 if !body.contains('.') {
387 return Err(PatternError::NoDot);
388 }
389 for label in body.split('.') {
390 if label.is_empty() {
391 return Err(PatternError::EmptyLabel);
392 }
393 if label.starts_with('-') || label.ends_with('-') {
394 return Err(PatternError::LabelHyphenBorder {
395 label: label.to_string(),
396 });
397 }
398 if !label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
399 return Err(PatternError::BadChar {
400 label: label.to_string(),
401 });
402 }
403 }
404 Ok(())
405}
406
407pub fn merge_into_allowlists(
421 allowlists: &mut Vec<SourceAllowlist>,
422 user_hosts: &[UserExtensionHost],
423) {
424 if user_hosts.is_empty() {
425 return;
426 }
427 if let Some(oa) = allowlists.iter_mut().find(|a| a.source == "oa-publisher") {
428 for h in user_hosts {
429 let s = h.host.as_str();
430 if !oa.redirect_hosts.iter().any(|p| p == s) {
431 oa.redirect_hosts.push(s.to_string());
432 }
433 }
434 return;
435 }
436 let mut new_patterns: Vec<String> = Vec::with_capacity(user_hosts.len());
437 for h in user_hosts {
438 let s = h.host.as_str().to_string();
439 if !new_patterns.contains(&s) {
440 new_patterns.push(s);
441 }
442 }
443 allowlists.push(SourceAllowlist::new("oa-publisher", new_patterns));
444}
445
// Unit tests for pattern validation, config parsing, file loading and
// allowlist merging. Panicking helpers are allowed here because a panic is
// how a test reports failure.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    // Shorthand for building the `&Utf8Path` that `parse_str` uses to label
    // errors.
    fn p(s: &str) -> &Utf8Path {
        Utf8Path::new(s)
    }

    // ---- validate_pattern: accepted shapes ----

    #[test]
    fn validate_pattern_accepts_literal_fqdn() {
        assert!(validate_pattern("ruj.uj.edu.pl").is_ok());
        assert!(validate_pattern("example.org").is_ok());
        assert!(validate_pattern("a.b.c.d.e").is_ok());
    }

    #[test]
    fn validate_pattern_accepts_single_suffix_wildcard() {
        assert!(validate_pattern("*.uj.edu.pl").is_ok());
        assert!(validate_pattern("*.aps.org").is_ok());
    }

    // ---- validate_pattern: one test per rejection class ----

    #[test]
    fn validate_pattern_rejects_empty() {
        assert_eq!(validate_pattern(""), Err(PatternError::Empty));
    }

    #[test]
    fn validate_pattern_rejects_whitespace() {
        assert_eq!(
            validate_pattern(" example.org"),
            Err(PatternError::Whitespace)
        );
        assert_eq!(
            validate_pattern("example.org "),
            Err(PatternError::Whitespace)
        );
    }

    #[test]
    fn validate_pattern_rejects_bare_wildcard() {
        assert_eq!(validate_pattern("*"), Err(PatternError::BareWildcard));
    }

    // A second `*` anywhere after the leading `*.` is a multi-segment glob.
    #[test]
    fn validate_pattern_rejects_multi_segment_globs() {
        for bad in ["*.edu.*", "*.ac.*", "*.*", "*.example.*"] {
            assert_eq!(
                validate_pattern(bad),
                Err(PatternError::MultiSegmentGlob),
                "{bad} should be MultiSegmentGlob"
            );
        }
    }

    // A `*` that is not part of a leading `*.` prefix is misplaced.
    #[test]
    fn validate_pattern_rejects_misplaced_wildcards() {
        for bad in ["foo.*.org", "f*o.bar", "*foo.bar"] {
            assert_eq!(
                validate_pattern(bad),
                Err(PatternError::MisplacedWildcard),
                "{bad} should be MisplacedWildcard"
            );
        }
    }

    // URL-ish inputs (userinfo, path, port, scheme) must not pass as hosts.
    // Either rejection kind is accepted by this test.
    #[test]
    fn validate_pattern_rejects_non_host_chars() {
        for bad in ["user@host.com", "host.com/", "host.com:80", "https://x.y"] {
            assert!(
                matches!(
                    validate_pattern(bad),
                    Err(PatternError::BadChar { .. }) | Err(PatternError::EmptyLabel)
                ),
                "{bad} should be BadChar or EmptyLabel; got {:?}",
                validate_pattern(bad)
            );
        }
    }

    #[test]
    fn validate_pattern_rejects_no_dot() {
        assert_eq!(validate_pattern("singlelabel"), Err(PatternError::NoDot));
    }

    // Leading dot, doubled dot and trailing dot all produce an empty label.
    #[test]
    fn validate_pattern_rejects_empty_label_classes() {
        for bad in [".example.org", "example..org", "example.org."] {
            assert_eq!(
                validate_pattern(bad),
                Err(PatternError::EmptyLabel),
                "{bad} should be EmptyLabel"
            );
        }
    }

    // The error must carry the specific label that has the bad hyphen.
    #[test]
    fn validate_pattern_rejects_hyphen_bordering_labels() {
        for (bad, label) in [
            ("-foo.example.org", "-foo"),
            ("foo.-example.org", "-example"),
            ("foo.example-.org", "example-"),
        ] {
            assert_eq!(
                validate_pattern(bad),
                Err(PatternError::LabelHyphenBorder {
                    label: label.to_string()
                }),
                "{bad} should be LabelHyphenBorder({label})"
            );
        }
    }

    #[test]
    fn validate_pattern_rejects_empty_suffix_after_wildcard() {
        assert_eq!(validate_pattern("*."), Err(PatternError::EmptySuffix));
    }

    // ---- HostPattern constructors all go through validation ----

    #[test]
    fn host_pattern_new_validates() {
        assert!(HostPattern::new("ruj.uj.edu.pl").is_ok());
        assert_eq!(HostPattern::new(""), Err(PatternError::Empty));
    }

    #[test]
    fn host_pattern_try_from_str_and_string() {
        let from_str: HostPattern = "*.aps.org".try_into().expect("ok");
        let from_string: HostPattern = String::from("*.aps.org").try_into().expect("ok");
        assert_eq!(from_str, from_string);
    }

    #[test]
    fn host_pattern_deserialize_validates() {
        let bad = toml::from_str::<HostPattern>("\"*.edu.*\"");
        assert!(bad.is_err(), "TOML deserialize MUST validate the pattern");
    }

    // ---- parse_str ----

    #[test]
    fn parse_empty_config_returns_no_hosts() {
        assert_eq!(parse_str("", p("config.toml")).unwrap(), vec![]);
    }

    // Tables other than `[network]` are tolerated and ignored.
    #[test]
    fn parse_config_without_network_section_returns_no_hosts() {
        let toml = r#"
 [store]
 root = "/tmp"
 "#;
        assert_eq!(parse_str(toml, p("config.toml")).unwrap(), vec![]);
    }

    // Keys in `[network]` other than `additional_hosts` are tolerated.
    #[test]
    fn parse_config_with_unknown_network_fields_is_accepted() {
        let toml = r#"
 [network]
 contact_email = "x@y.org"
 cooldown_ms = 250
 "#;
        assert_eq!(parse_str(toml, p("config.toml")).unwrap(), vec![]);
    }

    // Inside a host entry, unknown keys are a hard error so typos are caught.
    #[test]
    fn parse_rejects_unknown_field_inside_additional_hosts_entry() {
        let toml = r#"
 [[network.additional_hosts]]
 host = "ruj.uj.edu.pl"
 notez = "typo"
 "#;
        let err = parse_str(toml, p("config.toml")).expect_err("typo must fail");
        assert!(matches!(err, UserExtensionError::Parse { .. }));
    }

    #[test]
    fn parse_one_literal_host_with_note() {
        let toml = r#"
 [[network.additional_hosts]]
 host = "ruj.uj.edu.pl"
 note = "Jagiellonian University Repository"
 "#;
        let got = parse_str(toml, p("config.toml")).unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].host.as_str(), "ruj.uj.edu.pl");
        assert_eq!(
            got[0].note.as_deref(),
            Some("Jagiellonian University Repository")
        );
    }

    // Entries come back in file order, each with its own note.
    #[test]
    fn parse_multiple_hosts_mixed_literal_and_wildcard() {
        let toml = r#"
 [[network.additional_hosts]]
 host = "ruj.uj.edu.pl"

 [[network.additional_hosts]]
 host = "*.aps.org"
 note = "user override"
 "#;
        let got = parse_str(toml, p("config.toml")).unwrap();
        assert_eq!(got.len(), 2);
        assert_eq!(got[0].host.as_str(), "ruj.uj.edu.pl");
        assert!(got[0].note.is_none());
        assert_eq!(got[1].host.as_str(), "*.aps.org");
        assert_eq!(got[1].note.as_deref(), Some("user override"));
    }

    // Two bad entries around one good entry: both bad ones must be reported,
    // in file order, and the error must name the config path.
    #[test]
    fn parse_collects_all_invalid_patterns_not_just_first() {
        let toml = r#"
 [[network.additional_hosts]]
 host = "*.edu.*"

 [[network.additional_hosts]]
 host = "ok.example.org"

 [[network.additional_hosts]]
 host = "user@host.com"
 "#;
        let err = parse_str(toml, p("/home/u/.config/doiget/config.toml"))
            .expect_err("invalid patterns must error");
        match err {
            UserExtensionError::InvalidPatterns { path, issues } => {
                assert_eq!(path, "/home/u/.config/doiget/config.toml");
                assert_eq!(issues.len(), 2, "both bad patterns collected");
                assert_eq!(issues[0].pattern, "*.edu.*");
                assert_eq!(issues[0].kind, PatternError::MultiSegmentGlob);
                assert_eq!(issues[1].pattern, "user@host.com");
                assert!(matches!(
                    issues[1].kind,
                    PatternError::BadChar { .. } | PatternError::EmptyLabel
                ));
            }
            other => panic!("expected InvalidPatterns, got {other:?}"),
        }
    }

    #[test]
    fn parse_rejects_malformed_toml() {
        let err = parse_str("[[network.additional_hosts\nhost=\"foo\"", p("config.toml"))
            .expect_err("malformed toml must error");
        assert!(matches!(err, UserExtensionError::Parse { .. }));
    }

    // ---- load ----

    // A missing config file means "no extensions", not an error.
    #[test]
    fn load_returns_empty_when_file_missing() {
        let td = tempfile::TempDir::new().unwrap();
        let path = Utf8Path::from_path(td.path()).unwrap().join("missing.toml");
        let got = load(&path).expect("missing file MUST be Ok(empty)");
        assert_eq!(got, vec![]);
    }

    #[test]
    fn load_reads_real_file() {
        use std::io::Write;
        let td = tempfile::TempDir::new().unwrap();
        let path = Utf8Path::from_path(td.path()).unwrap().join("config.toml");
        let mut f = std::fs::File::create(path.as_std_path()).unwrap();
        f.write_all(
            br#"
[[network.additional_hosts]]
host = "ruj.uj.edu.pl"
note = "Jagiellonian"
"#,
        )
        .unwrap();
        let got = load(&path).expect("ok");
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].host.as_str(), "ruj.uj.edu.pl");
    }

    // ---- merge_into_allowlists ----

    // Existing "oa-publisher" entry: user hosts are appended after the
    // existing ones and no new allowlist is created.
    #[test]
    fn merge_appends_to_existing_oa_publisher_entry() {
        let mut allowlists = vec![
            SourceAllowlist::new("crossref", vec!["api.crossref.org".into()]),
            SourceAllowlist::new("oa-publisher", vec!["pmc.ncbi.nlm.nih.gov".into()]),
        ];
        let user_hosts = vec![UserExtensionHost::for_test("ruj.uj.edu.pl")];
        merge_into_allowlists(&mut allowlists, &user_hosts);

        let oa = allowlists
            .iter()
            .find(|a| a.source == "oa-publisher")
            .unwrap();
        assert_eq!(
            oa.redirect_hosts,
            vec![
                "pmc.ncbi.nlm.nih.gov".to_string(),
                "ruj.uj.edu.pl".to_string()
            ]
        );
        assert_eq!(allowlists.len(), 2);
    }

    // No "oa-publisher" entry yet: one is created to hold the user hosts.
    #[test]
    fn merge_creates_oa_publisher_entry_if_missing() {
        let mut allowlists = vec![SourceAllowlist::new(
            "crossref",
            vec!["api.crossref.org".into()],
        )];
        let user_hosts = vec![UserExtensionHost::for_test("ruj.uj.edu.pl")];
        merge_into_allowlists(&mut allowlists, &user_hosts);
        assert_eq!(allowlists.len(), 2);
        let oa = allowlists
            .iter()
            .find(|a| a.source == "oa-publisher")
            .unwrap();
        assert_eq!(oa.redirect_hosts, vec!["ruj.uj.edu.pl".to_string()]);
    }

    // With no user hosts the allowlists must be left exactly as they were;
    // in particular no empty "oa-publisher" entry is created.
    #[test]
    fn merge_is_noop_on_empty_user_hosts() {
        let mut allowlists = vec![SourceAllowlist::new(
            "crossref",
            vec!["api.crossref.org".into()],
        )];
        let snapshot: Vec<(String, Vec<String>)> = allowlists
            .iter()
            .map(|a| (a.source.clone(), a.redirect_hosts.clone()))
            .collect();
        merge_into_allowlists(&mut allowlists, &[]);
        let after: Vec<(String, Vec<String>)> = allowlists
            .iter()
            .map(|a| (a.source.clone(), a.redirect_hosts.clone()))
            .collect();
        assert_eq!(snapshot, after);
    }

    // Duplicates are dropped both against what is already in the entry and
    // among the user hosts themselves.
    #[test]
    fn merge_dedupes_against_existing_entries() {
        let mut allowlists = vec![SourceAllowlist::new(
            "oa-publisher",
            vec!["ruj.uj.edu.pl".into()],
        )];
        let user_hosts = vec![
            UserExtensionHost::for_test("ruj.uj.edu.pl"),
            UserExtensionHost::for_test("*.uj.edu.pl"),
            UserExtensionHost::for_test("*.uj.edu.pl"),
        ];
        merge_into_allowlists(&mut allowlists, &user_hosts);
        let oa = allowlists
            .iter()
            .find(|a| a.source == "oa-publisher")
            .unwrap();
        assert_eq!(
            oa.redirect_hosts,
            vec!["ruj.uj.edu.pl".to_string(), "*.uj.edu.pl".to_string()]
        );
    }

    #[test]
    fn merge_dedupes_when_creating_new_entry() {
        let mut allowlists = Vec::new();
        let user_hosts = vec![
            UserExtensionHost::for_test("ruj.uj.edu.pl"),
            UserExtensionHost::for_test("ruj.uj.edu.pl"),
        ];
        merge_into_allowlists(&mut allowlists, &user_hosts);
        assert_eq!(allowlists.len(), 1);
        assert_eq!(allowlists[0].redirect_hosts, vec!["ruj.uj.edu.pl"]);
    }

    // End-to-end: a wildcard parsed from config and merged in is honoured by
    // `SourceAllowlist::matches`.
    #[test]
    fn merged_pattern_is_matched_by_source_allowlist() {
        let parsed = parse_str(
            r#"
[[network.additional_hosts]]
host = "*.uj.edu.pl"
"#,
            p("config.toml"),
        )
        .unwrap();
        let mut allowlists = vec![SourceAllowlist::new("oa-publisher", vec![])];
        merge_into_allowlists(&mut allowlists, &parsed);
        let oa = allowlists
            .iter()
            .find(|a| a.source == "oa-publisher")
            .unwrap();
        assert!(oa.matches("ruj.uj.edu.pl"));
        assert!(oa.matches("alpha.uj.edu.pl"));
        assert!(!oa.matches("ruj.uj.edu.ru"));
    }
}