ccf吧 关注:1,671贴子:3,243
  • 1回复贴，共1页

【求助】CSP 第五题

只看楼主 收藏 回复

求助：实在不知道哪里错了。用自己写的一部分样例测过，只有加法和乘法时应该没有问题，一加上转置操作就出错了。懒标记的优先级是：转置 = 乘 > 加。
# First input line: n (positions 1..n covered by the tree) and m (number of
# operation lines that follow).
n, m = [int(tok) for tok in input().split()]
mod = 10 ** 9 + 7  # modulus applied to every stored sum and lazy tag
def pp():
    """Debug helper: print the node table one tree level per line.

    Level k of the tree occupies indices [2**k, 2**(k+1)), so each printed
    slice is one level; iteration stops once the index reaches 4 * n.
    """
    lo = 1
    while lo < 4 * n:
        print(node[lo:2 * lo])
        lo *= 2
class Node:
    """One segment-tree node covering the inclusive range [l, r].

    Attributes (documented as comments, set in __init__):
        l, r  -- inclusive bounds of the covered positions.
        s     -- the three component sums over [l, r].
        lza   -- pending component-wise addition still owed to the children.
        lzm   -- pending multiplier still owed to the children.
        lztn  -- pending count of (x, y, z) -> (y, z, x) rotations, kept mod 3.
    """

    def __init__(self, l, r):
        self.l, self.r = l, r
        self.s = [0] * 3
        self.lza = [0] * 3
        self.lzm = 1    # multiplicative identity: no pending scale
        self.lztn = 0   # no pending rotation

    def __repr__(self):
        # %s on a list formats it exactly like str(list).
        return "s:%s,a:%s,m:%d,tn:%d" % (self.s, self.lza, self.lzm, self.lztn)
# Segment-tree storage indexed from 1: node i's children are 2i and 2i+1, and
# slot 0 is never used.  4n slots suffice for a tree over positions 1..n.
# NOTE(review): this allocates O(n) slots up front; if n can be very large in
# the real test data, coordinate compression or on-demand node creation would
# be needed -- confirm against the problem limits.
node = [0] * (4 * n)


def make(l, r, i):
    """Create node i covering [l, r] and, recursively, every node beneath it."""
    node[i] = Node(l, r)
    if l != r:
        mid = (l + r) // 2
        make(l, mid, 2 * i)          # left half  [l, mid]
        make(mid + 1, r, 2 * i + 1)  # right half [mid + 1, r]


make(1, n, 1)  # root is node 1, covering the whole range 1..n
def up(i):
    """Rebuild node i's three component sums from its two children, mod `mod`."""
    left_sums = node[2 * i].s
    right_sums = node[2 * i + 1].s
    # Slice assignment keeps node[i].s as the same list object.
    node[i].s[:] = [(a + b) % mod for a, b in zip(left_sums, right_sums)]
def down(i):
    """Push node i's pending lazy tags onto its two children.

    A node's tags describe the transform its children still owe, applied in
    this order:
      1. rotate every vector ``lztn`` times, (x, y, z) -> (y, z, x) per step;
      2. multiply every component by ``lzm``;
      3. add ``lza`` component-wise.

    Whenever a tag is attached to a node (in ``add``, or here when a parent
    pushes down), that node's own ``s`` and ``lza`` are updated at the same
    moment.  So node i's own fields are already correct and must NOT be
    transformed again; only the children change.  On return node i carries
    no pending tags.
    """
    tn = node[i].lztn
    mul = node[i].lzm
    inc = node[i].lza  # alias; cleared in place at the end
    children = (node[i << 1], node[i << 1 | 1])

    if tn != 0:
        for child in children:
            # One rotation moves component j+1 into slot j, so tn rotations
            # move component j+tn into slot j.  The child's own pending add
            # tag was attached before this rotation, so it is rotated too.
            s, a = child.s[:], child.lza[:]
            for j in range(3):
                child.s[j] = s[(j + tn) % 3]
                child.lza[j] = a[(j + tn) % 3]
            child.lztn = (child.lztn + tn) % 3
        # Bug fix: the previous version also rotated node[i].s and
        # node[i].lza here.  Both were already rotated when the tag was
        # attached, so the add tag pushed below came out in the wrong
        # component order.  node[i].s was also left corrupted whenever no
        # up(i) followed (e.g. on the query path).
        node[i].lztn = 0

    if mul != 1:
        for child in children:
            # Scaling after (rotate, *m, +a) gives (rotate, *m*mul, +a*mul).
            child.lzm = child.lzm * mul % mod
            for j in range(3):
                child.lza[j] = child.lza[j] * mul % mod
                child.s[j] = child.s[j] * mul % mod
        node[i].lzm = 1

    if any(inc):
        for child in children:
            width = child.r - child.l + 1
            for j in range(3):
                if inc[j]:
                    child.lza[j] = (child.lza[j] + inc[j]) % mod
                    child.s[j] = (child.s[j] + width * inc[j]) % mod
        for j in range(3):
            inc[j] = 0
def add(L, R, *kw, i=1):
    """Apply one range update to positions [L, R] within node i's subtree.

    The number of extra positional arguments selects the operation:
      * ``add(L, R)``          -- rotate every vector once, (x, y, z) -> (y, z, x);
      * ``add(L, R, k)``       -- multiply every component by ``k``;
      * ``add(L, R, a, b, c)`` -- add ``(a, b, c)`` component-wise.
    Any other count leaves fully covered nodes untouched.

    A fully covered node gets its sums and pending tags updated in place and
    is not descended into.  Otherwise its pending tags are pushed down, the
    overlapping children are visited, and its sums are rebuilt.
    """
    cur = node[i]
    if not (L <= cur.l and cur.r <= R):
        down(i)
        mid = (cur.l + cur.r) >> 1
        if L <= mid:
            add(L, R, *kw, i=2 * i)
        if mid < R:
            add(L, R, *kw, i=2 * i + 1)
        up(i)
        return

    if not kw:
        # A rotation composes after any pending add tag, so that tag rotates
        # along with the sums.
        cur.lztn = (cur.lztn + 1) % 3
        cur.s[:] = cur.s[1:] + cur.s[:1]
        cur.lza[:] = cur.lza[1:] + cur.lza[:1]
    elif len(kw) == 1:
        (k,) = kw
        cur.lzm = cur.lzm * k % mod
        for j in range(3):
            cur.lza[j] = cur.lza[j] * k % mod
            cur.s[j] = cur.s[j] * k % mod
    elif len(kw) == 3:
        width = cur.r - cur.l + 1
        for j, delta in enumerate(kw):
            cur.lza[j] = (cur.lza[j] + delta) % mod
            cur.s[j] = (cur.s[j] + delta * width) % mod
# Accumulator filled by query(); the main loop rebinds it before each query.
SUM = [0] * 3


def query(L, R, i=1):
    """Add the three component sums over [L, R] (within node i) into SUM.

    Values are added without reduction.  Partially covered nodes push their
    pending tags down before their children are visited.
    """
    cur = node[i]
    if L <= cur.l and cur.r <= R:
        for j, value in enumerate(cur.s):
            SUM[j] += value
        return
    down(i)
    mid = (cur.l + cur.r) >> 1
    if L <= mid:
        query(L, R, 2 * i)
    if mid < R:
        query(L, R, 2 * i + 1)
# Operation loop.  Each line is "<op> l r [extras...]":
#   ops 1 and 2 forward their extra integers to add() (three extras add a
#   vector, one extra multiplies); op 3 rotates; op 4 prints the sum of the
#   squares of the three component sums over [l, r], reduced modulo `mod`.
for _ in range(m):
    tokens = input().rstrip().split()
    op = tokens[0]
    if op in ("1", "2"):
        l, r, *extra = map(int, tokens[1:])
        add(l, r, *extra)
    elif op == "3":
        l, r = map(int, tokens[1:])
        add(l, r)
    elif op == "4":
        l, r = map(int, tokens[1:])
        SUM = [0, 0, 0]  # fresh module-level accumulator read by query()
        query(l, r)
        print(sum(v * v for v in SUM) % mod)


IP属地:广东 1楼 2021-01-13 09:38 回复
    第21次 CSP 第五题。好家伙，贴上来代码的缩进全没了。


    IP属地:广东 2楼 2021-01-13 09:42
    回复