[NOI2016] 旷野大计算

Posted by Panda2134's Blog on February 19, 2018

UOJ-224

神题。

扑通一声跪下来,千古神犇vfk。


早就听说了这个造计算机题。正好,前几天上洛谷的省选课,这个题目作为提答作业布置了下来。于是我就开始了我的愉快作死之旅啦~

自己xjb乱搞,搞了56pts,发现不会做了XD

于是参考了chrt的题解vfleaking的slide,各种卡,终于拿到了95pts…

未完待续

工具

按照vfk的题解的说法,题目中给出的是神经网络,也就是说,计算中没有“修改变量”的操作,只有“输入→输出”的映射。这是不是有点像函数式编程呢?在试着写了前 3 个点和第 5 个点之后,我发现运算里面行号的处理非常麻烦,如果把行号 hard-code到代码里面,很难看,而且很难调试。不如用函数式的思想,把“节点”稍微包装一下,这样就比较好调试了。还有个问题,题目里面的 90 位小数,怎么实现?难道手写高精度?既然是提交答案题,就不一定要用 C++。python 自带高精度浮点数(decimal模块),而且语法很方便,就用 python 啦。NOI Linux自带python。

如下包装了几个基本命令:

#!/usr/bin/env python3

from decimal import Decimal
import decimal

decimal.getcontext().prec = 90

line = 1

def PutLine(str):
	print(str)
	global line
	line = line + 1
	return line - 1

class Node:
	lineno = 0

	def __init__(self, lineno):
		self.lineno = lineno

	def out(self): # 输出
		Node(PutLine('O {}'.format(self.lineno)))

	def shl(self, d): # 左移d位
		return Node(PutLine('< {} {}'.format(self.lineno, d)))

	def shr(self, d): # 右移d位
		return Node(PutLine('> {} {}'.format(self.lineno, d)))

	def add(self, y): # 加上y节点
		return Node(PutLine('+ {} {}'.format(self.lineno, y.lineno)))

	def opposite(self): # 取相反数
		return Node(PutLine('- {}'.format(self.lineno)))

	def sigmoid(self): # sigmoid函数
		return Node(PutLine('S {}'.format(self.lineno)))

	def offset(self, c): # 偏移常数c
		return Node(PutLine('C {} {:.90f}'.format(self.lineno, c)))

def readin(): # 输入
	return Node(PutLine('I'))

这样调用就很方便了,如第一个点,这么写即可:

readin().add(readin()).shl(1).opposite().out()

从上面一行可以直接地看出用了 $6$ 个基本操作。


以下混用 $="$ 和 $\approx”$.

测试点1-2

基本操作。

测试点3

实现函数:

利用题目中 $\text{Sigmoid}$ 函数的性质:在无穷远处趋近于 $0/1$ 。

容易发现,$s(a«100) = \begin{cases}0 & a<0 \ 1/2 & a = 0 \ 1 & a > 0 \end{cases}$ .

再平移变换即可满足题意。

代码:

def cmp(self):
	return self.shl(100).sigmoid().offset(-0.5).shl(1)

测试点4

这个点就很有意思了。

如果直接用测试点3+乘法,只能得到6分。

满分解法是这样的:

我们考虑构造 $\lvert x \rvert$ : $\lvert x \rvert = x - \min\{2x, 0\}$.

$\min\{2x, 0\}$ 怎么实现?

先构造上面式子的第一行。与正负有关,我们想到了构造第3个点的过程。 我们能不能引入某个量,让它在 $x>0$ 的时候能够去掉 $x$ 的贡献呢?怎么去掉贡献?发现 $s(x)$ 在 $x$ 趋近无穷大的时候趋近一个常数,这就是一种信息的丢失。利用这一点构造式子。如果我们要利用这一点,我们就得有方法把值从 $s(x)$ 还原到 $x$。考虑导数的定义:

于是我们可以利用导数来在某个点附近“线性拟合”某函数。对于 $s(x)$ ,不妨在 $x=0$ 处求导,于是有:

这样当 $x$ 接近 $0$ 的时候, $s(x)$ 取值就可以用 $y = \frac{1}{4}x + \frac{1}{2}$ 估计了。

那么开始构造吧:

这样就可以了。代入任何一个负数/正数发现都满足题意。没有处理 $x=0$ ,因为可以给输入统一偏移一个小常数来避免。卡一卡代码长度,可以令 $p = -p$, 即:

把 $3$ 个减号变成 $1$ 个后就可以拿到满分了。

代码:

def abs(self):
	x = self.offset(Decimal('1e-40'))
	c = x.shl(150).sigmoid().shl(152)
	r = x.shr(150).add(c).sigmoid()
	p = r.opposite().offset(Decimal('0.5')).shl(153).add(c)
	return x.add(p)

测试点5

bin-to-decimal转换。

直接搞就行了,需要大力卡常,连临时变量都不能用。

代码:

def bcd(t):
	for i in range(31):
		t[i] = t[i].shl(31-i)
	for i in range(1, 32):
		t[i] = t[i].add(t[i-1])
	return t[31]

bcd([readin() for i in range(32)]).out()

测试点6

decimal-to-bin转换。

同样是直接搞+大力卡常。

卡了3h+常,仍然只有 $8$ 分。原因不明。

代码:

def dcb(a): # 8 pts
	ret = list()
	a = a.offset(Decimal('1e-40'))
	one = a.shl(300).sigmoid()
	for i in range(31, 0, -1):
		 b = a.add(one.shl(i).opposite()).shl(300).sigmoid()
		 ret.append(b)
		 a = a.add(b.shl(i).opposite())
	ret.append(a)
	return ret

BinList = dcb(readin())
for i in BinList:
    i.out()

测试点7

同上 $8$ 分。原因同样不明。

按位处理即可。注意,这样实现用的节点数更少:$a \text{ xor } b = a+b-2s((a+b-1.5)«300)​$

按照vfk课件里面的说法,第 $6$ 个点写挫了第 $7$ 个点也会挂。好像说中了orz

代码:

def getxor(a, b): # 8 pts
	a = a.offset(Decimal('1e-40'))
	b = b.offset(Decimal('1e-40'))
	one = a.shl(300).sigmoid()
	ans = a.shr(300)
	for i in range(31, 0, -1):
		t = a.add(one.shl(i).opposite()).shl(300).sigmoid()
		r = b.add(one.shl(i).opposite()).shl(300).sigmoid()
		s = t.add(r)
		ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)).shl(i))
		a = a.add(t.shl(i).opposite())
		b = b.add(r.shl(i).opposite())
	s = a.add(b)
	ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)))
	return ans
getxor(readin(), readin()).out()

测试点8

有意思*2

除以一个一般常数,我没能想出来不用乘法节点的方法。看了题解才知道是再次利用导数来实现线性变换。

我的理解就是,题目虽然给出的是非线性变换,但是在某个点处取极限后,就可以在那个点的邻域看作线性变换了。

找到一个点 $x_0$ ,使得 $s’(x_0) = 0.1$ ,那么在那个点附近函数值可以看成满足直线 $y - s(x_0) = 0.1(x-x_0)$。

在这个点附近做 $x \rightarrow 0.1x$ 的线性变换即可。

那么问题来了:怎么找出一个导数值为 $0.1$ 的点呢?double精度是不够的,我们需要高精度的浮点数……

等等,用的是python啊,不是自带decimal么?

于是方法就清晰了:首先求一个满足 $90$ 位精度的 $\e$ (用泰勒展开,$e = 1 + \frac{1}{1!} + \frac{1}{2!} + \frac{1}{3!} + \frac{1}{4!} + \cdots$),然后二分查出一个满足条件的 $x_0$,再求出 $s(x_0)$ 。最后把 $x$ 缩到很小,做上述的线性变换,再放大回原来的倍数。

代码:

def gete(n):
	cur = 1
	e = Decimal('0')
	for i in range(1, n): # e = 1 + 1/1! + 1/2! + 1/3! + 1/4! + ...
		e += Decimal('1') / cur
		cur *= i
	return e

def f(x):
	return (1 / (1 + e**(-x)))

def g(x):
	return (1 / ((1 + e**(-x))**2)) * e**(-x) - Decimal('0.1')

def g10(): # x satisfying f'(x) = 1/10
	l = Decimal('2.0')
	r = Decimal('2.1')
	for i in range(1000):
		mid = (l + r) / 2
		if g(mid) < 0:
			r = mid
		else:
			l = mid
	return l

def div10(a):
	t = g10()
	dx = a.shr(100)
	x2 = dx.offset(t)
	dy = x2.sigmoid().offset(-f(t))
	return dy.shl(100)

div10(readin()).out()

测试点9

只适用排序网络,因为没有if语句。

本蒟蒻不会双调排序……不过这个题目冒泡排序足矣。毕竟 $n = 16$ 。

考虑比较器的实现。定义

显然 $\text{hlp}(x) = (x+ \lvert x \rvert)»1$.

于是这么构造:

则 $x’ = \min\{x, y\}, y’ = \max\{x, y\}$.

再按照冒泡排序摆一堆比较器就行了。

代码:

Node类中:

def hlp(self):
	return self.add(self.abs()).shr(1)
def comparator(x, y):
	d = x.add(y.opposite()).hlp()
	return (x.add(d.opposite()), y.add(d))

主程序中:

def bubblesort(nd):
	for s in range(15, 0, -1):
		for i in range(s):
			nd[i], nd[i+1] = nd[i].comparator(nd[i+1]) # 丢一堆比较器
	return nd

list(map(Node.out, bubblesort([readin() for i in range(16)]))) # list: 迭代map来解析整个结果列表

测试点10

求 $a \cdot b \text{ mod } m$.

快速乘法。

这个点我是这样做的:我们定义 $\text{hlp}2(x)$:

这个怎么实现呢?类比第四个测试点,利用上 $\text{cmp}(x)$ .

$ret$ 即为 $\text{hlp2}(x)$ 的值。

然后再写快速乘法:首先用倍增的方法去掉 $a$ 中的所有 $m$, 使得 $a \in [0, m)$ . 然后每次迭代需要把 $[0, 2m)$ 内的一个数字对 $m$ 取模,这个用上面的函数可以实现。直接这么写的话,共 $2196$ 行,可以获得 $9$ 分。

代码:

Node类中:

def hlp2(self, c):
	sgn = self.offset(Decimal('-1e-40')).shl(150).sigmoid().shl(151)
	t = c.shr(150).add(sgn).sigmoid()
	ret = t.offset(Decimal('-0.5')).shl(152).add(sgn.opposite())
	return ret

主程序:

def fastmul(a, x, m):
	ret = a.shr(150)
	one = a.offset(300).sigmoid()
	minusone = one.opposite()
	minusm = m.opposite()
	binx = dcb(x)
	for i in range(31, -1, -1):
		t = m.shl(i)
		d = a.add(t.opposite()).offset(Decimal('1e-40'))
		a = a.add(d.opposite().hlp2(t).opposite())
	for i in range(31, -1, -1):
		ret = ret.add(binx[i].opposite().offset(Decimal('1')).hlp2(a))
		ret = ret.add(ret.add(minusm).opposite().hlp2(minusm))
		a = a.shl(1)
		a = a.add(a.add(minusm).opposite().hlp2(minusm))
	return ret

fastmul(readin(), readin(), readin()).out()

代码

放个总的代码:

#!/usr/bin/env python3

from decimal import Decimal
import decimal

decimal.getcontext().prec = 90

line = 1

def PutLine(str):
	print(str)
	global line
	line = line + 1
	return line - 1

class Node:
	lineno = 0

	def __init__(self, lineno):
		self.lineno = lineno

	def out(self):
		Node(PutLine('O {}'.format(self.lineno)))

	def shl(self, d):
		return Node(PutLine('< {} {}'.format(self.lineno, d)))

	def shr(self, d):
		return Node(PutLine('> {} {}'.format(self.lineno, d)))

	def add(self, y):
		return Node(PutLine('+ {} {}'.format(self.lineno, y.lineno)))

	def opposite(self):
		return Node(PutLine('- {}'.format(self.lineno)))

	def sigmoid(self):
		return Node(PutLine('S {}'.format(self.lineno)))

	def offset(self, c):
		return Node(PutLine('C {} {:.90f}'.format(self.lineno, c)))

	def cmp(self):
		return self.shl(100).sigmoid().offset(-0.5).shl(1)

	def abs(self):
		x = self.offset(Decimal('1e-40'))
		c = x.shl(150).sigmoid().shl(152)
		r = x.shr(150).add(c).sigmoid()
		p = r.opposite().offset(Decimal('0.5')).shl(153).add(c)
		return x.add(p)

	def hlp(self):
		'''hlp(x) = { x if x>=0, 0 otherwise'''
		return self.add(self.abs()).shr(1)

	def hlp2(self, c):
		'''hlp2(x) = { c if x<=0, 0 otherwise'''
		sgn = self.offset(Decimal('-1e-40')).shl(150).sigmoid().shl(151)
		t = c.shr(150).add(sgn).sigmoid()
		ret = t.offset(Decimal('-0.5')).shl(152).add(sgn.opposite())
		return ret

	def comparator(x, y):
		d = x.add(y.opposite()).hlp()
		return (x.add(d.opposite()), y.add(d))

def readin():
	return Node(PutLine('I'))

# helper functions

def gete(n):
	cur = 1
	e = Decimal('0')
	for i in range(1, n): # e = 1 + 1/1! + 1/2! + 1/3! + 1/4! + ...
		e += Decimal('1') / cur
		cur *= i
	return e

def f(x):
	return (1 / (1 + e**(-x)))

def g(x):
	return (1 / ((1 + e**(-x))**2)) * e**(-x) - Decimal('0.1')

def g10(): # x satisfying f'(x) = 1/10
	l = Decimal('2.0')
	r = Decimal('2.1')
	for i in range(1000):
		mid = (l + r) / 2
		if g(mid) < 0:
			r = mid
		else:
			l = mid
	return l

# helper end

def bubblesort(nd):
	for s in range(15, 0, -1):
		for i in range(s):
			nd[i], nd[i+1] = nd[i].comparator(nd[i+1])
	return nd

def dcb(a): # 8 pts
	ret = list()
	a = a.offset(Decimal('1e-40'))
	one = a.shl(300).sigmoid()
	for i in range(31, 0, -1):
		 b = a.add(one.shl(i).opposite()).shl(300).sigmoid()
		 ret.append(b)
		 a = a.add(b.shl(i).opposite())
	ret.append(a)
	return ret

def bcd(t):
	for i in range(31):
		t[i] = t[i].shl(31-i)
	for i in range(1, 32):
		t[i] = t[i].add(t[i-1])
	return t[31]

def getxor(a, b): # 8 pts
	a = a.offset(Decimal('1e-40'))
	b = b.offset(Decimal('1e-40'))
	one = a.shl(300).sigmoid()
	ans = a.shr(300)
	for i in range(31, 0, -1):
		t = a.add(one.shl(i).opposite()).shl(300).sigmoid()
		r = b.add(one.shl(i).opposite()).shl(300).sigmoid()
		s = t.add(r)
		ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)).shl(i))
		a = a.add(t.shl(i).opposite())
		b = b.add(r.shl(i).opposite())
	s = a.add(b)
	ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)))
	return ans

def div10(a):
	t = g10()
	dx = a.shr(100)
	x2 = dx.offset(t)
	dy = x2.sigmoid().offset(-f(t))
	return dy.shl(100)

def fastmul(a, x, m):
	ret = a.shr(150)
	one = a.offset(300).sigmoid()
	minusone = one.opposite()
	minusm = m.opposite()
	binx = dcb(x)
	for i in range(31, -1, -1):
		t = m.shl(i)
		d = a.add(t.opposite()).offset(Decimal('1e-40'))
		a = a.add(d.opposite().hlp2(t).opposite())
	for i in range(31, -1, -1):
		ret = ret.add(binx[i].opposite().offset(Decimal('1')).hlp2(a))
		ret = ret.add(ret.add(minusm).opposite().hlp2(minusm))
		a = a.shl(1)
		a = a.add(a.add(minusm).opposite().hlp2(minusm))
	return ret

def main():
	global e
	e = gete(100) # 标准库中e的精度不够,用泰勒展开算一个达到1e-90精度的
	# your solution goes here
	pass

if __name__ == '__main__':
	main()