use crate::collections::HashMap;
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::hash::Hash;
use std::iter::FromIterator;
use std::ops::{BitAnd, BitOr, BitXor, Sub};
/// A hash set implemented as a thin wrapper around `HashMap<T, ()>`.
///
/// Every element is stored as a key of the backing map with the unit value
/// `()`, so all operations delegate to the map.
pub struct HashSet<T>
where
    T: Hash + Eq,
{
    /// Backing map; only its keys carry information.
    hash_map: HashMap<T, ()>,
}

impl<T> HashSet<T>
where
    T: Hash + Eq,
{
    /// Creates an empty set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of elements in the set.
    pub fn len(&self) -> usize {
        self.hash_map.len()
    }

    /// Returns `true` if the set holds no elements.
    pub fn is_empty(&self) -> bool {
        self.hash_map.is_empty()
    }

    /// Adds `value` to the set.
    ///
    /// Returns `true` if the value was not present before the call and
    /// `false` if it already was.
    pub fn insert(&mut self, value: T) -> bool {
        self.hash_map.insert(value, ()).is_none()
    }

    /// Returns `true` if the set contains `value`.
    ///
    /// The value may be given in any borrowed form `Q` of the element type,
    /// e.g. `&str` for a set of `String`s.
    pub fn contains<Q>(&self, value: &Q) -> bool
    where
        T: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.hash_map.get(value).is_some()
    }

    /// Removes `value` from the set.
    ///
    /// Returns `true` if the value was present and `false` otherwise. Accepts
    /// the same borrowed forms as [`HashSet::contains`].
    pub fn remove<Q>(&mut self, value: &Q) -> bool
    where
        T: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.hash_map.remove(value).is_some()
    }

    /// Returns an iterator over references to the elements, in whatever order
    /// the backing map yields its keys.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.hash_map.iter().map(|(element, _)| element)
    }

    /// Lazily yields every element that is in `self` or in `other`.
    ///
    /// All elements of `self` come first, followed by the elements of `other`
    /// that `self` lacks, so no element is yielded twice.
    pub fn union<'a>(&'a self, other: &'a HashSet<T>) -> impl Iterator<Item = &'a T> {
        self.iter().chain(other.difference(self))
    }

    /// Lazily yields the elements of `self` that are not in `other`.
    pub fn difference<'a>(&'a self, other: &'a HashSet<T>) -> impl Iterator<Item = &'a T> {
        self.iter().filter(move |item| !other.contains(*item))
    }

    /// Lazily yields the elements that are in exactly one of the two sets.
    pub fn symmetric_difference<'a>(
        &'a self,
        other: &'a HashSet<T>,
    ) -> impl Iterator<Item = &'a T> {
        self.difference(other).chain(other.difference(self))
    }

    /// Lazily yields the elements of `self` that are also in `other`.
    pub fn intersection<'a>(&'a self, other: &'a HashSet<T>) -> impl Iterator<Item = &'a T> {
        self.iter().filter(move |item| other.contains(*item))
    }

    /// Returns `true` if the two sets have no element in common.
    pub fn is_disjoint(&self, other: &HashSet<T>) -> bool {
        // Walk the smaller set and probe the larger one, stopping at the first
        // shared element instead of counting the whole intersection.
        let (small, large) = if self.len() <= other.len() {
            (self, other)
        } else {
            (other, self)
        };
        small.iter().all(|item| !large.contains(item))
    }

    /// Returns `true` if every element of `self` is also in `other`.
    pub fn is_subset(&self, other: &HashSet<T>) -> bool {
        // A set with more elements can never be a subset, so skip the scan.
        self.len() <= other.len() && self.iter().all(|item| other.contains(item))
    }

    /// Returns `true` if every element of `other` is also in `self`.
    pub fn is_superset(&self, other: &HashSet<T>) -> bool {
        other.is_subset(self)
    }
}

/// The default set is the empty set.
impl<T> Default for HashSet<T>
where
    T: Hash + Eq,
{
    fn default() -> Self {
        Self {
            hash_map: HashMap::new(),
        }
    }
}
impl<T> PartialEq for HashSet<T>
where
T: Hash + Eq,
{
fn eq(&self, other: &HashSet<T>) -> bool {
if self.len() != other.len() {
return false;
}
self.iter().all(|item| other.contains(&item))
}
}
/// Marker impl: set equality as defined by `PartialEq` is a full equivalence
/// relation because the elements themselves are `Eq`.
impl<T: Hash + Eq> Eq for HashSet<T> {}
impl<T> PartialOrd for HashSet<T>
where
T: Hash + Eq,
{
fn partial_cmp(&self, other: &HashSet<T>) -> Option<Ordering> {
let is_subset = self.is_subset(other);
let same_size = self.len() == other.len();
match (is_subset, same_size) {
(true, true) => Some(Ordering::Equal),
(true, false) => Some(Ordering::Less),
(false, true) => None,
_ => Some(Ordering::Greater).filter(|_| self.is_superset(other)),
}
}
}
/// Builds a set from any iterator of elements; repeated values are inserted
/// again and therefore end up in the set only once.
impl<T: Hash + Eq> FromIterator<T> for HashSet<T> {
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let mut set = Self::new();
        for value in iter {
            // The "was it new?" flag returned by `insert` is irrelevant here.
            set.insert(value);
        }
        set
    }
}
/// `&a | &b` builds a new set holding clones of the elements in the union of
/// `a` and `b`; neither operand is consumed.
impl<'a, 'b, T: Hash + Eq + Clone> BitOr<&'b HashSet<T>> for &'a HashSet<T> {
    type Output = HashSet<T>;

    fn bitor(self, rhs: &'b HashSet<T>) -> Self::Output {
        let merged = self.union(rhs);
        merged.cloned().collect()
    }
}
/// `&a - &b` builds a new set holding clones of the elements of `a` that are
/// not in `b`; neither operand is consumed.
impl<'a, 'b, T: Hash + Eq + Clone> Sub<&'b HashSet<T>> for &'a HashSet<T> {
    type Output = HashSet<T>;

    fn sub(self, rhs: &'b HashSet<T>) -> Self::Output {
        let remaining = self.difference(rhs);
        remaining.cloned().collect()
    }
}
/// `&a ^ &b` builds a new set holding clones of the elements that are in
/// exactly one of `a` and `b`; neither operand is consumed.
impl<'a, 'b, T: Hash + Eq + Clone> BitXor<&'b HashSet<T>> for &'a HashSet<T> {
    type Output = HashSet<T>;

    fn bitxor(self, rhs: &'b HashSet<T>) -> Self::Output {
        let exclusive = self.symmetric_difference(rhs);
        exclusive.cloned().collect()
    }
}
/// `&a & &b` builds a new set holding clones of the elements that are in both
/// `a` and `b`; neither operand is consumed.
impl<'a, 'b, T: Hash + Eq + Clone> BitAnd<&'b HashSet<T>> for &'a HashSet<T> {
    type Output = HashSet<T>;

    fn bitand(self, rhs: &'b HashSet<T>) -> Self::Output {
        let shared = self.intersection(rhs);
        shared.cloned().collect()
    }
}
/// Tests for construction, insertion, lookup, removal and collecting.
#[cfg(test)]
mod basics {
    use super::*;

    /// A freshly created set is empty.
    #[test]
    fn basic() {
        let s: HashSet<String> = HashSet::new();
        assert_eq!(s.len(), 0);
        assert!(s.is_empty());
    }

    /// `insert` reports whether the value was new and never stores duplicates.
    #[test]
    fn insert() {
        let mut s = HashSet::new();

        assert!(s.insert("cat"));
        assert_eq!(s.len(), 1);

        assert!(s.insert("dog"));
        assert_eq!(s.len(), 2);

        assert!(
            !s.insert("dog"),
            "Attempting to insert present value returns false"
        );
        assert_eq!(s.len(), 2, "Certain value can only be inserted to set once");
    }

    /// `contains` works with the element type and with its borrowed form.
    #[test]
    fn contains() {
        let mut s1: HashSet<&str> = HashSet::new();
        s1.insert("cat");
        assert!(
            s1.contains("cat"),
            "contains() returns true for present value"
        );
        assert!(
            !s1.contains("dog"),
            "contains() returns false for absent value"
        );

        let mut s2: HashSet<String> = HashSet::new();
        s2.insert("cat".to_string());
        assert!(s2.contains(&"cat".to_string()), "Can query with String");
        assert!(s2.contains("cat"), "Can query with &str");
    }

    /// `remove` reports whether the value was present and actually drops it.
    #[test]
    fn remove() {
        let mut s1: HashSet<&str> = HashSet::new();
        s1.insert("cat");
        assert!(s1.contains("cat"), "'cat' exists before remove()");
        assert!(s1.remove("cat"), "Successful removal returns true");
        assert!(!s1.contains("cat"), "'cat' is gone after remove()");
        assert!(
            !s1.remove("elephant"),
            "Trying to remove non-existing value returns false"
        );

        let mut s2: HashSet<String> = HashSet::new();
        s2.insert("cat".to_string());
        s2.insert("dog".to_string());
        assert!(s2.remove(&"cat".to_string()), "Can remove with String");
        assert!(
            !s2.contains("cat"),
            "Successfully removed value with String"
        );
        assert!(s2.remove("dog"), "Can remove with &str");
        assert!(!s2.contains("dog"), "Successfully removed value with &str");
    }

    /// Collecting an iterator yields a set with exactly the iterated values.
    #[test]
    fn from_iter() {
        let s1: HashSet<_> = ["cat", "dog", "rat"].iter().cloned().collect();
        for animal in ["cat", "dog", "rat"].iter() {
            assert!(s1.contains(*animal));
        }
        assert_eq!(s1.len(), 3);
    }
}
/// Tests for the set-algebra iterators and the relation predicates.
#[cfg(test)]
mod set_relations {
    use super::*;

    /// Builds a set of string slices from the given items.
    fn set_of<'a>(items: &[&'a str]) -> HashSet<&'a str> {
        items.iter().cloned().collect()
    }

    #[test]
    fn union() {
        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert_eq!(s1.union(&s2).count(), 0, "∅ ∪ ∅ = ∅");

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        let union: HashSet<_> = s1.union(&s2).cloned().collect();
        assert!(union == set_of(&["cat"]));

        let (s1, s2) = (set_of(&["cat"]), set_of(&[]));
        let union: HashSet<_> = s1.union(&s2).cloned().collect();
        assert!(union == set_of(&["cat"]));

        let (s1, s2) = (set_of(&["cat", "dog"]), set_of(&["cat", "rat"]));
        let union: HashSet<_> = s1.union(&s2).cloned().collect();
        assert!(union == set_of(&["cat", "dog", "rat"]));
    }

    #[test]
    fn intersection() {
        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert_eq!(s1.intersection(&s2).count(), 0, "∅ ∩ ∅ = ∅");

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        assert_eq!(s1.intersection(&s2).count(), 0);

        let (s1, s2) = (set_of(&["cat"]), set_of(&[]));
        assert_eq!(s1.intersection(&s2).count(), 0);

        let (s1, s2) = (set_of(&["cat", "dog"]), set_of(&["cat", "rat"]));
        let intersection: HashSet<_> = s1.intersection(&s2).cloned().collect();
        assert!(intersection == set_of(&["cat"]));
    }

    #[test]
    fn difference() {
        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert_eq!(s1.difference(&s2).count(), 0, r"∅ \ ∅ = ∅");

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        assert_eq!(s1.difference(&s2).count(), 0);

        let (s1, s2) = (set_of(&["cat"]), set_of(&[]));
        let difference: HashSet<_> = s1.difference(&s2).cloned().collect();
        assert!(difference == set_of(&["cat"]));

        let (s1, s2) = (set_of(&["cat", "dog"]), set_of(&["cat", "rat"]));
        let difference: HashSet<_> = s1.difference(&s2).cloned().collect();
        assert!(difference == set_of(&["dog"]));
    }

    #[test]
    fn symmetric_difference() {
        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert_eq!(s1.symmetric_difference(&s2).count(), 0, "∅ △ ∅ = ∅");

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        let symmetric_difference: HashSet<_> = s1.symmetric_difference(&s2).cloned().collect();
        assert!(symmetric_difference == set_of(&["cat"]));

        let (s1, s2) = (set_of(&["cat"]), set_of(&[]));
        let symmetric_difference: HashSet<_> = s1.symmetric_difference(&s2).cloned().collect();
        assert!(symmetric_difference == set_of(&["cat"]));

        let (s1, s2) = (set_of(&["cat", "dog"]), set_of(&["cat", "rat"]));
        let symmetric_difference: HashSet<_> = s1.symmetric_difference(&s2).cloned().collect();
        assert!(symmetric_difference == set_of(&["dog", "rat"]));
    }

    #[test]
    fn is_disjoint() {
        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert!(s1.is_disjoint(&s2), "∅, ∅ are disjoint");

        // The message goes through "{}" so its literal braces are not parsed
        // as format placeholders.
        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        assert!(s1.is_disjoint(&s2), "{}", "∅, {cat} are disjoint");
        assert!(s2.is_disjoint(&s1), "{}", "∅, {cat} are disjoint");

        let (s1, s2) = (set_of(&["rat"]), set_of(&["cat"]));
        assert!(s1.is_disjoint(&s2));

        let (s1, s2) = (set_of(&["cat"]), set_of(&["cat"]));
        assert!(!s1.is_disjoint(&s2));
        assert!(!s2.is_disjoint(&s1));
    }

    #[test]
    fn is_subset() {
        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert!(s1.is_subset(&s2), "∅ ⊆ ∅");
        assert!(s2.is_subset(&s1), "∅ ⊆ ∅");

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        assert!(s1.is_subset(&s2), "∀𝑨: ∅ ⊆ 𝑨");

        let (s1, s2) = (set_of(&["cat"]), set_of(&[]));
        assert!(!s1.is_subset(&s2), "∀𝑨, 𝑨 ≠ ∅: 𝑨 ⊈ ∅");

        let (s1, s2) = (set_of(&["cat"]), set_of(&["cat"]));
        assert!(s1.is_subset(&s2));

        let (s1, s2) = (set_of(&["cat"]), set_of(&["cat", "rat"]));
        assert!(s1.is_subset(&s2));

        let (s1, s2) = (set_of(&["cat", "rat"]), set_of(&["cat"]));
        assert!(!s1.is_subset(&s2));
    }

    #[test]
    fn is_superset() {
        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert!(s1.is_superset(&s2), "∅ ⊇ ∅");
        assert!(s2.is_superset(&s1), "∅ ⊇ ∅");

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        assert!(!s1.is_superset(&s2), "∀𝑨, 𝑨 ≠ ∅: ∅ ⊉ 𝑨");

        let (s1, s2) = (set_of(&["cat"]), set_of(&[]));
        assert!(s1.is_superset(&s2), "∀𝑨: 𝑨 ⊇ ∅");

        let (s1, s2) = (set_of(&["cat"]), set_of(&["cat"]));
        assert!(s1.is_superset(&s2));

        let (s1, s2) = (set_of(&["cat"]), set_of(&["cat", "rat"]));
        assert!(!s1.is_superset(&s2));

        let (s1, s2) = (set_of(&["cat", "rat"]), set_of(&["cat"]));
        assert!(s1.is_superset(&s2));
    }
}
/// Tests for the `|`, `&`, `-` and `^` operators on set references.
#[cfg(test)]
mod logical_ops {
    use super::*;

    /// Builds a set of string slices from the given items.
    fn set_of<'a>(items: &[&'a str]) -> HashSet<&'a str> {
        items.iter().cloned().collect()
    }

    #[test]
    fn bitor() {
        let s1 = set_of(&["cat", "dog"]);
        let s2 = set_of(&["cat", "rat"]);
        let union = &s1 | &s2;
        assert!(union == set_of(&["cat", "dog", "rat"]));
        // The operator takes references, so both operands must remain usable.
        assert_eq!(s1.len(), 2, "s1 is still available");
        assert_eq!(s2.len(), 2, "s2 is still available");
    }

    #[test]
    fn bitand() {
        let s1 = set_of(&["cat", "dog"]);
        let s2 = set_of(&["cat", "rat"]);
        let intersection: HashSet<_> = &s1 & &s2;
        assert!(intersection == set_of(&["cat"]));
    }

    #[test]
    fn sub() {
        let s1 = set_of(&["cat", "dog"]);
        let s2 = set_of(&["cat", "rat"]);
        let difference = &s1 - &s2;
        assert!(difference == set_of(&["dog"]));
    }

    #[test]
    fn bitxor() {
        let s1 = set_of(&["cat", "dog"]);
        let s2 = set_of(&["cat", "rat"]);
        let symmetric_difference: HashSet<_> = &s1 ^ &s2;
        assert!(symmetric_difference == set_of(&["dog", "rat"]));
    }
}
/// Tests for `==`/`!=` and for the inclusion-based partial ordering.
#[cfg(test)]
mod cmp_ops {
    use super::*;

    /// Builds a set of string slices from the given items.
    fn set_of<'a>(items: &[&'a str]) -> HashSet<&'a str> {
        items.iter().cloned().collect()
    }

    /// Checks `partial_cmp` and every comparison operator for `lhs` versus
    /// `rhs` against the single `expected` ordering (`None` = incomparable).
    ///
    /// `case` only labels the failure message.
    fn assert_ordering<'a>(
        case: &str,
        lhs: &HashSet<&'a str>,
        rhs: &HashSet<&'a str>,
        expected: Option<Ordering>,
    ) {
        let is_eq = expected == Some(Ordering::Equal);
        let is_lt = expected == Some(Ordering::Less);
        let is_gt = expected == Some(Ordering::Greater);

        assert_eq!(lhs.partial_cmp(rhs), expected, "{}: partial_cmp", case);
        assert_eq!(lhs > rhs, is_gt, "{}: >", case);
        assert_eq!(lhs >= rhs, is_gt || is_eq, "{}: >=", case);
        assert_eq!(lhs < rhs, is_lt, "{}: <", case);
        assert_eq!(lhs <= rhs, is_lt || is_eq, "{}: <=", case);
        assert_eq!(lhs == rhs, is_eq, "{}: ==", case);
    }

    #[test]
    fn eq() {
        let set = set_of(&["cat", "dog", "rat"]);

        let identical = set_of(&["cat", "dog", "rat"]);
        assert!(set == identical, "sets of identical elements are equal");

        let reordered = set_of(&["rat", "cat", "dog"]);
        assert!(set == reordered, "order of elements doesn't matter");

        let different = set_of(&["cat", "dog", "elephant"]);
        assert!(set != different);

        let superset = set_of(&["cat", "dog", "rat", "elephant"]);
        assert!(set != superset);

        let subset = set_of(&["cat"]);
        assert!(set != subset);

        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert!(s1 == s2, "∅ = ∅");

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        assert!(s1 != s2);

        let (s1, s2) = (set_of(&["cat"]), set_of(&[]));
        assert!(s1 != s2);
    }

    #[test]
    fn partial_cmp() {
        let set = set_of(&["cat", "dog", "rat"]);

        let identical = set_of(&["cat", "dog", "rat"]);
        assert_ordering("identical", &set, &identical, Some(Ordering::Equal));

        // Same size but different elements: neither contains the other.
        let different = set_of(&["cat", "dog", "elephant"]);
        assert_ordering("different", &set, &different, None);

        let superset = set_of(&["cat", "dog", "rat", "elephant"]);
        assert_ordering("superset", &set, &superset, Some(Ordering::Less));

        let subset = set_of(&["cat"]);
        assert_ordering("subset", &set, &subset, Some(Ordering::Greater));

        let (s1, s2) = (set_of(&[]), set_of(&[]));
        assert_ordering("empty vs empty", &s1, &s2, Some(Ordering::Equal));

        let (s1, s2) = (set_of(&[]), set_of(&["cat"]));
        assert_ordering("empty vs non-empty", &s1, &s2, Some(Ordering::Less));
    }
}