BitSet

前幾天幹了一件比較無聊的事兒——抄了一遍 C++ STL 中 bitset 的源代碼,把看不懂的宏定義去掉了,發現(暫時)還能用,嘿嘿。

#ifndef BITSET_H
#define BITSET_H
#include <climits>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>

/// Fixed-size sequence of `Bits` bits modelled on std::bitset (adapted from an
/// STL std::bitset implementation with its configuration macros removed).
///
/// Storage: `Words + 1` words of type `Type`; bit i lives in word
/// i / BitsPerWord at position i % BitsPerWord. Invariant: every storage bit
/// at a position >= Bits is zero (restored by trim()); count(), any(),
/// operator== and to_ullong() inspect whole words and rely on it.
///
/// Position-taking members (set, reset, flip, test, the const at() and
/// operator[]) throw std::out_of_range when pos >= Bits.
template <size_t Bits>
class BitSet {
public:
	typedef bool element_type;
	/// Storage word type; bits are packed BitsPerWord to a word.
	typedef unsigned int Type;

	/// Proxy for a single bit, returned by the non-const at() and operator[].
	/// It stores a pointer to the owning BitSet, so it must not outlive it.
	/// Reads and writes go through test()/set(), which do the range check.
	class Reference {
		friend class BitSet<Bits>;
	public:
		~Reference() {}

		/// Writes `val` into the referenced bit.
		Reference& operator=(bool val) {
			p->set(pos, val);
			return *this;
		}

		/// Copies the *value* of another referenced bit into this one.
		Reference& operator=(const Reference& bitRef) {
			p->set(pos, bool(bitRef));
			return *this;
		}

		/// Inverts the referenced bit in place.
		Reference& flip() {
			p->flip(pos);
			return *this;
		}

		/// Returns the inverse of the referenced bit without modifying it.
		bool operator~() const {
			return !p->test(pos);
		}

		/// Reads the referenced bit.
		operator bool() const {
			return p->test(pos);
		}

	private:
		Reference() : p(0), pos(0) {}

		Reference(BitSet<Bits>& bitSet, size_t pos) : p(&bitSet), pos(pos) {}

		BitSet<Bits> *p;
		size_t pos;
	};

	/// Range-checked read of bit `pos` (throws std::out_of_range if pos >= Bits).
	bool at(size_t pos) const {
		return test(pos);
	}

	/// Returns a proxy for bit `pos`; the range check happens on access.
	Reference at(size_t pos) {
		return Reference(*this, pos);
	}

	/// Range-checked read of bit `pos` (throws std::out_of_range if pos >= Bits).
	bool operator[](size_t pos) const {
		return test(pos);
	}

	/// Returns a proxy for bit `pos`; the range check happens on access.
	Reference operator[](size_t pos) {
		return Reference(*this, pos);
	}

	/// Constructs an all-zero set.
	BitSet() {
		tidy();
	}

	/// Constructs from the low `Bits` bits of `val` (bit 0 of val -> bit 0);
	/// higher bits of val are ignored.
	BitSet(unsigned long long val) {
		tidy();
		for (size_t pos = 0; val != 0 && pos < Bits; val >>= 1, ++pos) {
			if (val & 1)
				set(pos);
		}
	}

	/// Parses at most `count` characters of `str` starting at `pos` (and at
	/// most Bits of them); the last parsed character becomes bit 0.
	/// @throws std::out_of_range   if pos > str.size().
	/// @throws std::invalid_argument if a parsed character is neither e0 nor e1.
	BitSet(const std::string& str,
		std::string::size_type pos = 0,
		std::string::size_type count = std::string::npos,
		char e0 = '0',
		char e1 = '1') {
		construct(str, pos, count, e0, e1);
	}

	/// Like the std::string constructor; when count == npos, `ptr` is read
	/// as a NUL-terminated string, otherwise its first `count` chars are used.
	BitSet(const char *ptr,
		std::string::size_type count = std::string::npos,
		char e0 = '0',
		char e1 = '1') {
		construct(count == std::string::npos ?
			std::string(ptr) : std::string(ptr, count), 0, count, e0, e1);
	}

	/// Shared parser behind the string constructors (see their docs).
	void construct(const std::string& str,
		std::string::size_type pos,
		std::string::size_type count,
		char e0,
		char e1) {
		std::string::size_type num;
		if (str.size() < pos)
			xran();
		if (str.size() - pos < count)
			count = str.size() - pos;
		if (Bits < count)
			count = Bits;
		tidy();

		// Walk the chosen substring right-to-left so its last char is bit 0.
		for (pos += count, num = 0; num < count; ++num) {
			if (str[--pos] == e1)
				set(num);
			else if (str[pos] != e0)
				xinv();
		}
	}

	/// Bitwise AND with `right`, in place.
	BitSet<Bits>& operator&=(const BitSet<Bits>& right) {
		for (ptrdiff_t pos = Words; 0 <= pos; --pos) {
			array[pos] &= right.getWord(pos);
		}
		return *this;
	}

	/// Bitwise OR with `right`, in place.
	BitSet<Bits>& operator|=(const BitSet<Bits>& right) {
		for (ptrdiff_t pos = Words; 0 <= pos; --pos) {
			array[pos] |= right.getWord(pos);
		}
		return *this;
	}

	/// Bitwise XOR with `right`, in place.
	BitSet<Bits>& operator^=(const BitSet<Bits>& right) {
		for (ptrdiff_t pos = Words; 0 <= pos; --pos) {
			array[pos] ^= right.getWord(pos);
		}
		return *this;
	}

	/// Shifts towards higher bit positions by `pos`, filling with zeros;
	/// bits moved to positions >= Bits are discarded.
	BitSet<Bits>& operator<<=(size_t pos) {
		const ptrdiff_t wordShift = (ptrdiff_t)(pos / BitsPerWord);
		if (wordShift != 0) {
			// Move whole words first (high to low so sources are read before overwrite).
			for (ptrdiff_t wpos = Words; 0 <= wpos; --wpos)
				array[wpos] = (wordShift <= wpos ? 
					array[wpos - wordShift] : 0);
		}
		if ((pos %= BitsPerWord) != 0) {
			// 0 < pos < BitsPerWord, shift by bits
			for (ptrdiff_t wpos = Words; 0 < wpos; --wpos)
				array[wpos] = (Type)((array[wpos] << pos) |
					(array[wpos - 1]) >> (BitsPerWord - pos));
			array[0] <<= pos;
		}
		trim();
		return *this;
	}

	/// Shifts towards lower bit positions by `pos`, filling with zeros.
	/// No trim() is needed: bits only move downwards.
	BitSet<Bits>& operator>>=(size_t pos) {
		const ptrdiff_t wordShift = (ptrdiff_t)(pos / BitsPerWord);
		if (wordShift != 0) {
			for (ptrdiff_t wpos = 0; wpos <= Words; ++wpos)
				array[wpos] =
					(wordShift <= Words - wpos ? array[wpos + wordShift] : 0);
		}
		if ((pos %= BitsPerWord) != 0) {
			// 0 < pos < BitsPerWord, shift by bits
			for (ptrdiff_t wpos = 0; wpos < Words; ++wpos)
				array[wpos] = (Type)((array[wpos] >> pos) |
					(array[wpos + 1] << (BitsPerWord - pos)));
			array[Words] >>= pos;
		}
		return *this;
	}

	/// Sets every bit to 1.
	BitSet<Bits>& set() {
		tidy((Type)~0);
		return *this;
	}

	/// Sets bit `pos` to `val`.
	/// @throws std::out_of_range if pos >= Bits.
	BitSet<Bits>& set(size_t pos, bool val = true) {
		if (Bits <= pos)
			xran();
		if (val)
			array[pos / BitsPerWord] |= (Type)1 << pos % BitsPerWord;
		else
			array[pos / BitsPerWord] &= ~((Type)1 << pos % BitsPerWord);
		return *this;
	}

	/// Clears every bit.
	BitSet<Bits>& reset() {
		tidy();
		return *this;
	}

	/// Clears bit `pos`.
	/// @throws std::out_of_range if pos >= Bits.
	BitSet<Bits>& reset(size_t pos) {
		return set(pos, false);
	}

	/// Returns a copy with every bit inverted.
	/// Returned by value: returning a reference here would refer to the
	/// temporary copy, which is destroyed at the end of the full expression.
	BitSet<Bits> operator~() const {
		return BitSet<Bits>(*this).flip();
	}

	/// Inverts every bit in place.
	BitSet<Bits>& flip() {
		for (ptrdiff_t pos = Words; 0 <= pos; --pos)
			array[pos] = (Type)~array[pos];
		trim();  // the unused high bits were turned on; clear them again
		return *this;
	}

	/// Inverts bit `pos`.
	/// @throws std::out_of_range if pos >= Bits.
	BitSet<Bits>& flip(size_t pos) {
		if (Bits <= pos)
			xran();
		array[pos / BitsPerWord] ^= (Type)1 << pos % BitsPerWord;
		return *this;
	}

	/// Returns the value as unsigned long.
	/// @throws std::overflow_error if it does not fit.
	unsigned long to_ulong() const {
		unsigned long long val = to_ullong();
		unsigned long ans = (unsigned long)val;
		if (ans != val)
			xofllo();
		return ans;
	}

	/// Returns the value as unsigned long long.
	/// @throws std::overflow_error if any set bit lies in a word beyond the
	///         width of unsigned long long.
	unsigned long long to_ullong() const {
		ptrdiff_t pos = Words;
		for (; (ptrdiff_t)(sizeof(unsigned long long) / sizeof(Type)) <= pos; --pos) {
			if (array[pos] != 0)
				xofllo();
		}
		unsigned long long val = array[pos];
		while (0 <= --pos)
			// Two-step shift avoids UB if BitsPerWord equals the width of val.
			val = ((val << (BitsPerWord - 1)) << 1) | array[pos];
		return val;
	}

	/// Renders the set as Bits characters, highest bit first.
	std::string to_string(char e0 = '0', char e1 = '1') const {
		std::string str;
		str.reserve(Bits);
		std::string::size_type pos = Bits;
		while (0 < pos) {
			if (test(--pos))
				str += e1;
			else
				str += e0;
		}
		return str;
	}

	/// Returns the number of set bits.
	size_t count() const {
		// Popcount of every possible byte value, indexed by the byte.
		const char *const BitsPerByte =
			"\0\1\1\2\1\2\2\3\1\2\2\3\2\3\3\4"
			"\1\2\2\3\2\3\3\4\2\3\3\4\3\4\4\5"
			"\1\2\2\3\2\3\3\4\2\3\3\4\3\4\4\5"
			"\2\3\3\4\3\4\4\5\3\4\4\5\4\5\5\6"
			"\1\2\2\3\2\3\3\4\2\3\3\4\3\4\4\5"
			"\2\3\3\4\3\4\4\5\3\4\4\5\4\5\5\6"
			"\2\3\3\4\3\4\4\5\3\4\4\5\4\5\5\6"
			"\3\4\4\5\4\5\5\6\4\5\5\6\5\6\6\7"
			"\1\2\2\3\2\3\3\4\2\3\3\4\3\4\4\5"
			"\2\3\3\4\3\4\4\5\3\4\4\5\4\5\5\6"
			"\2\3\3\4\3\4\4\5\3\4\4\5\4\5\5\6"
			"\3\4\4\5\4\5\5\6\4\5\5\6\5\6\6\7"
			"\2\3\3\4\3\4\4\5\3\4\4\5\4\5\5\6"
			"\3\4\4\5\4\5\5\6\4\5\5\6\5\6\6\7"
			"\3\4\4\5\4\5\5\6\4\5\5\6\5\6\6\7"
			"\4\5\5\6\5\6\6\7\5\6\6\7\6\7\7\x8";
		// Scans all storage bytes; correct because unused bits are kept zero.
		const unsigned char *ptr = (const unsigned char *)(const void *)array;
		const unsigned char *const end = ptr + sizeof(array);
		size_t val = 0;
		for (; ptr != end; ++ptr)
			val += BitsPerByte[*ptr];
		return val;
	}

	/// Returns Bits.
	size_t size() const {
		return Bits;
	}

	/// True when every bit matches `right`.
	bool operator==(const BitSet<Bits>& right) const {
		for (ptrdiff_t pos = Words; 0 <= pos; --pos) {
			if (array[pos] != right.getWord(pos))
				return false;
		}
		return true;
	}

	bool operator!=(const BitSet<Bits>& right) const {
		return !(*this == right);
	}

	/// Returns the value of bit `pos`.
	/// @throws std::out_of_range if pos >= Bits (pos == Bits used to slip
	///         through and read an unused or out-of-bounds bit).
	bool test(size_t pos) const {
		if (Bits <= pos)
			xran();
		return (array[pos / BitsPerWord] & ((Type)1 << pos % BitsPerWord)) != 0;
	}

	/// True if at least one bit is set.
	bool any() const {
		for (ptrdiff_t pos = Words; 0 <= pos; --pos) {
			if (array[pos] != 0)
				return true;
		}
		return false;
	}

	/// True if no bit is set.
	bool none() const {
		return !any();
	}

	/// True if every bit is set (vacuously true when Bits == 0).
	bool all() const {
		return count() == size();
	}

	/// Returns a copy shifted left by `pos`.
	BitSet<Bits> operator<<(size_t pos) const {
		return BitSet<Bits>(*this) <<= pos;
	}

	/// Returns a copy shifted right by `pos`.
	BitSet<Bits> operator>>(size_t pos) const {
		return BitSet<Bits>(*this) >>= pos;
	}

	/// Returns raw storage word `pos` (unchecked; valid for pos <= Words).
	Type getWord(size_t pos) const {
		return array[pos];
	}

private:
	static const ptrdiff_t BitsPerWord = (ptrdiff_t)(CHAR_BIT * sizeof(Type));
	/// Index of the last storage word (the array holds Words + 1 words).
	static const ptrdiff_t Words =
		(ptrdiff_t)(Bits == 0 ? 0 : (Bits - 1) / BitsPerWord);

	/// Fills every storage word with `val`, then restores the zero-padding
	/// invariant when val may have touched the unused bits.
	void tidy(Type val = 0) {
		for (ptrdiff_t pos = Words; 0 <= pos; --pos)
			array[pos] = val;
		if (val != 0)
			trim();
	}

	/// Clears the unused high bits of the last word. When Bits == 0 the single
	/// storage word is entirely unused: the mask (1 << 0) - 1 is 0 and clears it.
	void trim() {
		if (Bits == 0 || Bits % BitsPerWord != 0)
			array[Words] &= ((Type)1 << Bits % BitsPerWord) - 1;
	}

	void xinv() const {
		throw std::invalid_argument("invalid BitSet<N> char");
	}

	void xofllo() const {
		throw std::overflow_error("BitSet<N> overflow");
	}

	void xran() const {
		throw std::out_of_range("invalid BitSet<N> position");
	}

	Type array[Words + 1];
};

/// Returns the bitwise AND of two equally sized BitSets; neither operand
/// is modified.
template<size_t Bits>
BitSet<Bits> operator&(const BitSet<Bits>& left, const BitSet<Bits>& right) {
	return BitSet<Bits>(left) &= right;
}

/// Returns the bitwise OR of two equally sized BitSets; neither operand
/// is modified.
template<size_t Bits>
BitSet<Bits> operator|(const BitSet<Bits>& left, const BitSet<Bits>& right) {
	return BitSet<Bits>(left) |= right;
}

/// Returns the bitwise XOR of two equally sized BitSets; neither operand
/// is modified.
template<size_t Bits>
BitSet<Bits> operator^(const BitSet<Bits>& left, const BitSet<Bits>& right) {
	return BitSet<Bits>(left) ^= right;
}

/// Writes `right` to `os` as its '0'/'1' text form, highest bit first.
template<size_t Bits>
std::ostream& operator<<(std::ostream& os, const BitSet<Bits>& right) {
	const std::string text = right.to_string();
	os << text;
	return os;
}

/// Reads a BitSet from `is`, in the manner of std::bitset's operator>>.
///
/// After the sentry (which by default skips leading whitespace), up to
/// right.size() characters are consumed while they are '0' or '1'; the first
/// other character is left in the stream. The first character read becomes
/// the most significant of the extracted bits.
///
/// Stream state: eofbit if input ended, failbit if nothing was extracted,
/// badbit if reading threw (the original exception is rethrown only when
/// badbit is set in is.exceptions()). Unless an exception propagates, `right`
/// is reassigned from the extracted digits, so it becomes all-zero when
/// nothing was read.
template<size_t Bits>
std::istream& operator>>(std::istream& is, BitSet<Bits>& right) {
	typedef std::istream::traits_type Traits;
	const char e1 = '1';
	const char e0 = '0';
	std::ios_base::iostate state = std::ios_base::goodbit;
	bool changed = false;
	std::string str;
	const std::istream::sentry ok(is);

	if (ok) {
		try {
			Traits::int_type meta = is.rdbuf()->sgetc();
			for (size_t count = right.size();
			     0 < count;
			     meta = is.rdbuf()->snextc(), --count) {
				if (Traits::eq_int_type(meta, Traits::eof())) {
					state |= std::ios_base::eofbit;
					break;
				}
				const char c = Traits::to_char_type(meta);
				if (c != e0 && c != e1)
					break;  // leave the non-digit in the stream
				if (str.max_size() <= str.size()) {
					state |= std::ios_base::failbit;
					break;
				}
				str.append(1, c == e1 ? '1' : '0');
				changed = true;
			}
		}
		catch (...) {
			// Standard replacement for MSVC's non-standard setstate(badbit, true):
			// record badbit, then rethrow the original exception only if the
			// caller enabled exceptions for badbit.
			try {
				is.setstate(std::ios_base::badbit);
			}
			catch (...) {
				// setstate() stores badbit before throwing ios_base::failure;
				// the original exception is rethrown below instead.
			}
			if (is.exceptions() & std::ios_base::badbit)
				throw;
		}
	}
	if (!changed)
		state |= std::ios_base::failbit;
	is.setstate(state);
	right = BitSet<Bits>(str);
	return is;
}

#endif
相關文章
相關標籤/搜索