Select Git revision
BufferDataSourceTest.cpp
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test.py 1.03 KiB
import torch
import numpy as np
from src.data.generators import KSAT_Generator
from src.csp.csp_data import CSP_Data
# ksat_generator = KSAT_Generator(min_n=3, max_n=3, min_k=2, max_k=2, min_alpha=1.0, max_alpha=1.0)
#
# # Create a random Boolean satisfiability instance
# csp_data_instance = ksat_generator.create_random_instance()
#
# logits = torch.ones((csp_data_instance.num_val,), device=csp_data_instance.device, dtype=torch.float32)
# assignment, _ = csp_data_instance.hard_assign_sample(logits)
#
# print("Generated Assignment:", assignment)
# is_satisfied = csp_data_instance.constraint_is_sat(assignment)
# Scratch check: find the columns of a small integer matrix that repeat an
# earlier column (columns are compared across every row).
x = torch.tensor([[1, 0, 1, 0, 2],
                  [2, 0, 2, 0, 1]])
# Broadcast the rows against themselves: a is (2, 1, 5), b is (2, 5, 1).
a = x.unsqueeze(1)
b = x.unsqueeze(-1)
print(a)
# Pairwise element equality per row, AND-ed over the rows via prod, gives a
# (5, 5) 0/1 matrix: c[i, j] == 1 iff columns i and j of x are identical.
c = torch.eq(a, b).prod(dim=0)
# Keep only entries strictly below the diagonal (j < i), so rep[i] == 1 iff
# column i duplicates some earlier column.
rep = c.tril(diagonal=-1).max(dim=1).values
u = 1 - rep  # 1 for the first occurrence of each distinct column, else 0
rep.masked_fill_(rep == 1, -1)  # repeats become -1; first occurrences stay 0
# NOTE(review): every position has either u == 1 or rep == -1, so this
# logical_or is always all True -- confirm what output was actually intended.
print(u.logical_or(rep))
# a = torch.randint(2, size=(2,2,5))
# print(a)
# indices = torch.nonzero(a == 1, as_tuple=False)
# Flatten the indices to a 1D tensor
# flattened_indices = indices.view(-1)
# print(indices)