Issue
I have a problem with deserialization in Java 11 that results in a HashMap with a key that can't be found. I would appreciate it if anyone with more knowledge about the issue could say if my proposed workaround looks ok, or if there is something better I could do.
Consider the following contrived implementation (the relationships in the real problem are a bit more complex and hard to change):
import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
public class Element implements Serializable {
    // Must be static final to be honored by the serialization runtime
    private static final long serialVersionUID = 1L;
    private final int id;
    private final Map<Element, Integer> idFromElement = new HashMap<>();
    public Element(int id) {
        this.id = id;
    }
    public void addAll(Collection<Element> elements) {
        elements.forEach(e -> idFromElement.put(e, e.id));
    }
    public Integer idFrom(Element element) {
        return idFromElement.get(element);
    }
    @Override
    public int hashCode() {
        return id;
    }
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Element)) {
            return false;
        }
        Element other = (Element) obj;
        return this.id == other.id;
    }
}
Then I create an instance that has a reference to itself and serialize and deserialize it:
public static void main(String[] args) {
    List<Element> elements = Arrays.asList(new Element(111), new Element(222));
    Element originalElement = elements.get(1);
    originalElement.addAll(elements);
    Storage<Element> storage = new Storage<>();
    storage.serialize(originalElement);
    Element retrievedElement = storage.deserialize();
    if (retrievedElement.idFrom(retrievedElement) == 222) {
        System.out.println("ok");
    }
}
If I run this code in Java 8 the result is "ok"; if I run it in Java 11 the result is a NullPointerException because retrievedElement.idFrom(retrievedElement) returns null.
I put a breakpoint at HashMap.hash() and noticed that:
- In Java 8, when idFromElement is being deserialized and Element(222) is being added to it, its id is 222, so I am able to find it later.
- In Java 11, the id is not initialized (0 for int or null if I make it an Integer), so hash() is 0 when it's stored in the HashMap. Later, when I try to retrieve it, the id is 222, so idFromElement.get(element) returns null.
I understand that the sequence here is deserialize(Element(222)) -> deserialize(idFromElement) -> put unfinished Element(222) into Map. But, for some reason, in Java 8 id is already initialized when we get to the last step, while in Java 11 it is not.
The solution I came up with was to make idFromElement transient and write custom writeObject and readObject methods to force idFromElement to be deserialized after id:
...
private transient Map<Element, Integer> idFromElement = new HashMap<>();
...
private void writeObject(ObjectOutputStream output) throws IOException {
    output.defaultWriteObject();
    output.writeObject(idFromElement);
}
@SuppressWarnings("unchecked")
private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
    input.defaultReadObject();
    idFromElement = (HashMap<Element, Integer>) input.readObject();
}
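For reference, the workaround can be exercised end to end with a minimal sketch like the one below. The names WorkaroundDemo and roundTrip are assumptions made for illustration, Element is a condensed copy of the class above, and plain in-memory byte arrays stand in for the Storage<T> class.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class WorkaroundDemo {

    // Condensed copy of Element with the transient map and custom hooks.
    static class Element implements Serializable {
        private static final long serialVersionUID = 1L;
        private final int id;
        // transient: skipped by default serialization, handled manually below
        private transient Map<Element, Integer> idFromElement = new HashMap<>();

        Element(int id) {
            this.id = id;
        }

        void addAll(Collection<Element> elements) {
            elements.forEach(e -> idFromElement.put(e, e.id));
        }

        Integer idFrom(Element element) {
            return idFromElement.get(element);
        }

        @Override
        public int hashCode() {
            return id;
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof Element && ((Element) obj).id == id;
        }

        private void writeObject(ObjectOutputStream output) throws IOException {
            output.defaultWriteObject();        // writes the non-transient field id
            output.writeObject(idFromElement);  // map is written after the default fields
        }

        @SuppressWarnings("unchecked")
        private void readObject(ObjectInputStream input)
                throws IOException, ClassNotFoundException {
            input.defaultReadObject();          // id has been assigned once this returns
            idFromElement = (Map<Element, Integer>) input.readObject();
        }
    }

    // Hypothetical helper: serializes to a byte array and reads the object back.
    static Element roundTrip(Element element) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
            out.writeObject(element);
        }
        try (ObjectInputStream in = new ObjectInputStream(
                new ByteArrayInputStream(buffer.toByteArray()))) {
            return (Element) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        List<Element> elements = Arrays.asList(new Element(111), new Element(222));
        Element original = elements.get(1);
        original.addAll(elements);
        Element retrieved = roundTrip(original);
        System.out.println(retrieved.idFrom(retrieved));
    }
}
```

The point of the sketch is only the ordering inside readObject: the map is read after defaultReadObject() has returned, so the self-reference is hashed with its final id.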
The only reference I was able to find about the order during serialization/deserialization was this:
For serializable classes, the SC_SERIALIZABLE flag is set, the number of fields counts the number of serializable fields and is followed by a descriptor for each serializable field. The descriptors are written in canonical order. The descriptors for primitive typed fields are written first sorted by field name followed by descriptors for the object typed fields sorted by field name. The names are sorted using String.compareTo.
This is the same in both the Java 8 and Java 11 docs, and seems to imply that primitive typed fields should be written first, so I expected there would be no difference.
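The canonical order described in that passage can be inspected with ObjectStreamClass.lookup, as in the following minimal sketch. The classes FieldOrderDemo and Sample and the field names are made up for illustration. It only shows the order of the field descriptors; it says nothing about when the values read from the stream are assigned to the object.

```java
import java.io.ObjectStreamClass;
import java.io.ObjectStreamField;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class FieldOrderDemo {

    // Hypothetical class; the declaration order deliberately differs from the canonical order.
    static class Sample implements Serializable {
        private static final long serialVersionUID = 1L;
        String name;
        int id;
        Map<String, Integer> lookup;
        long count;
    }

    // Returns the serializable field names in the order reported by the class descriptor.
    static List<String> fieldNames(Class<?> type) {
        List<String> names = new ArrayList<>();
        for (ObjectStreamField field : ObjectStreamClass.lookup(type).getFields()) {
            names.add(field.getName());
        }
        return names;
    }

    public static void main(String[] args) {
        System.out.println(fieldNames(Sample.class));
    }
}
```

The primitive fields count and id should be listed before the object fields lookup and name, each group sorted by name, which matches the quoted rule.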
Implementation of Storage<T> included for completeness:
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
public class Storage<T> {
    private final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    public void serialize(T object) {
        buffer.reset();
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(buffer)) {
            objectOutputStream.writeObject(object);
            objectOutputStream.flush();
        } catch (IOException ioe) {
            ioe.printStackTrace();
        }
    }
    @SuppressWarnings("unchecked")
    public T deserialize() {
        ByteArrayInputStream byteArrayIS = new ByteArrayInputStream(buffer.toByteArray());
        try (ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayIS)) {
            return (T) objectInputStream.readObject();
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }
}
Solution
As mentioned in the comments and encouraged by the asker, here are the parts of the code that changed between version 8 and version 11 that I assume to be the reason for the different behavior (based on reading and debugging).
The difference is in the ObjectInputStream class, in one of its core methods. This is the relevant part of the implementation in Java 8:
private void readSerialData(Object obj, ObjectStreamClass desc)
    throws IOException
{
    ObjectStreamClass.ClassDataSlot[] slots = desc.getClassDataLayout();
    for (int i = 0; i < slots.length; i++) {
        ObjectStreamClass slotDesc = slots[i].desc;
        if (slots[i].hasData) {
            if (obj == null || handles.lookupException(passHandle) != null) {
                ...
            } else {
                defaultReadFields(obj, slotDesc);
            }
            ...
        }
    }
}
/**
 * Reads in values of serializable fields declared by given class
 * descriptor. If obj is non-null, sets field values in obj. Expects that
 * passHandle is set to obj's handle before this method is called.
 */
private void defaultReadFields(Object obj, ObjectStreamClass desc)
    throws IOException
{
    Class<?> cl = desc.forClass();
    if (cl != null && obj != null && !cl.isInstance(obj)) {
        throw new ClassCastException();
    }
    int primDataSize = desc.getPrimDataSize();
    if (primVals == null || primVals.length < primDataSize) {
        primVals = new byte[primDataSize];
    }
    bin.readFully(primVals, 0, primDataSize, false);
    if (obj != null) {
        desc.setPrimFieldValues(obj, primVals);
    }
    int objHandle = passHandle;
    ObjectStreamField[] fields = desc.getFields(false);
    Object[] objVals = new Object[desc.getNumObjFields()];
    int numPrimFields = fields.length - objVals.length;
    for (int i = 0; i < objVals.length; i++) {
        ObjectStreamField f = fields[numPrimFields + i];
        objVals[i] = readObject0(f.isUnshared());
        if (f.getField() != null) {
            handles.markDependency(objHandle, passHandle);
        }
    }
    if (obj != null) {
        desc.setObjFieldValues(obj, objVals);
    }
    passHandle = objHandle;
}
...
The method calls defaultReadFields, which reads the values of the fields. As mentioned in the quoted part of the specification, it first handles the field descriptors of primitive fields. The values that are read for these fields are set immediately after reading them, with
    desc.setPrimFieldValues(obj, primVals);
and importantly: This happens before it calls readObject0 for each of the non-primitive fields.
In contrast to that, here is the relevant part of the implementation of Java 11:
private void readSerialData(Object obj, ObjectStreamClass desc)
    throws IOException
{
    ObjectStreamClass.ClassDataSlot[] slots = desc.getClassDataLayout();
    ...
    for (int i = 0; i < slots.length; i++) {
        ObjectStreamClass slotDesc = slots[i].desc;
        if (slots[i].hasData) {
            if (obj == null || handles.lookupException(passHandle) != null) {
                ...
            } else {
                FieldValues vals = defaultReadFields(obj, slotDesc);
                if (slotValues != null) {
                    slotValues[i] = vals;
                } else if (obj != null) {
                    defaultCheckFieldValues(obj, slotDesc, vals);
                    defaultSetFieldValues(obj, slotDesc, vals);
                }
            }
            ...
        }
    }
    ...
}
private class FieldValues {
    final byte[] primValues;
    final Object[] objValues;
    FieldValues(byte[] primValues, Object[] objValues) {
        this.primValues = primValues;
        this.objValues = objValues;
    }
}
/**
 * Reads in values of serializable fields declared by given class
 * descriptor. Expects that passHandle is set to obj's handle before this
 * method is called.
 */
private FieldValues defaultReadFields(Object obj, ObjectStreamClass desc)
    throws IOException
{
    Class<?> cl = desc.forClass();
    if (cl != null && obj != null && !cl.isInstance(obj)) {
        throw new ClassCastException();
    }
    byte[] primVals = null;
    int primDataSize = desc.getPrimDataSize();
    if (primDataSize > 0) {
        primVals = new byte[primDataSize];
        bin.readFully(primVals, 0, primDataSize, false);
    }
    Object[] objVals = null;
    int numObjFields = desc.getNumObjFields();
    if (numObjFields > 0) {
        int objHandle = passHandle;
        ObjectStreamField[] fields = desc.getFields(false);
        objVals = new Object[numObjFields];
        int numPrimFields = fields.length - objVals.length;
        for (int i = 0; i < objVals.length; i++) {
            ObjectStreamField f = fields[numPrimFields + i];
            objVals[i] = readObject0(f.isUnshared());
            if (f.getField() != null) {
                handles.markDependency(objHandle, passHandle);
            }
        }
        passHandle = objHandle;
    }
    return new FieldValues(primVals, objVals);
}
...
An inner class, FieldValues, has been introduced. The defaultReadFields method now only reads the field values, and returns them as a FieldValues object. Afterwards, the returned values are assigned to the fields, by passing this FieldValues object to a newly introduced defaultSetFieldValues method, which internally does the desc.setPrimFieldValues(obj, primValues) call that originally was done immediately after the primitive values had been read.
To emphasize this again: The defaultReadFields method first reads the primitive field values. Then it reads the non-primitive field values. But it does so before the primitive field values have been set!
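The effect of that reordering can be modelled without any streams. The following toy sketch is not the JDK code: ReadOrderModel, Node, readEager and readDeferred are hypothetical names, and the "stream" is just an int parameter. It only contrasts assigning the primitive field before versus after the nested map has been populated.

```java
import java.util.HashMap;
import java.util.Map;

public class ReadOrderModel {

    // Hypothetical stand-in for a self-referencing serializable class.
    static class Node {
        int id;                     // determines hashCode, like Element.id
        Map<Node, Integer> map;

        @Override
        public int hashCode() {
            return id;
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof Node && ((Node) o).id == id;
        }
    }

    // Models the Java 8 order: the primitive value is assigned, then object fields are read.
    static Node readEager(int idFromStream) {
        Node node = new Node();
        node.id = idFromStream;
        Map<Node, Integer> map = new HashMap<>();
        map.put(node, idFromStream);    // key is hashed with its final id
        node.map = map;
        return node;
    }

    // Models the Java 11 order: all values are read first and assigned afterwards.
    static Node readDeferred(int idFromStream) {
        Node node = new Node();
        Map<Node, Integer> map = new HashMap<>();
        map.put(node, idFromStream);    // key is hashed while id is still 0
        node.id = idFromStream;         // assigned only after the map was filled
        node.map = map;
        return node;
    }

    public static void main(String[] args) {
        Node eager = readEager(222);
        Node deferred = readDeferred(222);
        System.out.println(eager.map.get(eager));
        System.out.println(deferred.map.get(deferred));
    }
}
```

In this model the eager variant finds its own entry, while the deferred variant gets null for the same lookup, which mirrors the difference observed between the two Java versions.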
This new process interferes with the deserialization method of HashMap. Again, the relevant part is shown here:
private void readObject(java.io.ObjectInputStream s)
    throws IOException, ClassNotFoundException {
    ...
    int mappings = s.readInt(); // Read number of mappings (size)
    if (mappings < 0)
        throw new InvalidObjectException("Illegal mappings count: " +
                                         mappings);
    else if (mappings > 0) { // (if zero, use defaults)
        ...
        Node<K,V>[] tab = (Node<K,V>[])new Node[cap];
        table = tab;
        // Read the keys and values, and put the mappings in the HashMap
        for (int i = 0; i < mappings; i++) {
            @SuppressWarnings("unchecked")
                K key = (K) s.readObject();
            @SuppressWarnings("unchecked")
                V value = (V) s.readObject();
            putVal(hash(key), key, value, false, false);
        }
    }
}
It reads the key and value objects one by one and puts them into the table by computing the hash of the key and using the internal putVal method. This is the same method that is used when manually populating the map (i.e. when it is filled programmatically, and not deserialized).
Holger already gave a hint in the comments why this is necessary: There is no guarantee that the hash code of the deserialized keys will be the same as before the serialization. So blindly "restoring the original array" could basically lead to objects being stored in the table under a wrong hash code.
But here, the opposite happens: The keys (i.e. the objects of type Element) are deserialized. They contain the idFromElement map, which in turn contains the Element objects. These elements are put into the map using the putVal method while the Element objects are still in the process of being deserialized. But due to the changed order in ObjectInputStream, this is done before the primitive value of the id field (which determines the hash code) has been set. So the objects are stored using hash code 0, and later, the id value is assigned (e.g. the value 222), causing the objects to end up in the table under a hash code that they actually no longer have.
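The resulting state can be reproduced with a plain mutable key, independent of serialization. In this minimal sketch, StaleHashDemo, MutableKey and buildStaleMap are hypothetical names; the key's hash code depends on a field that changes after insertion.

```java
import java.util.HashMap;
import java.util.Map;

public class StaleHashDemo {

    // Hypothetical key whose hash code depends on a mutable field.
    static class MutableKey {
        int id;

        @Override
        public int hashCode() {
            return id;
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof MutableKey && ((MutableKey) o).id == id;
        }
    }

    // Inserts the key while its id is still 0, then changes the id to 222.
    static Map<MutableKey, Integer> buildStaleMap(MutableKey key) {
        Map<MutableKey, Integer> map = new HashMap<>();
        map.put(key, 222);   // stored in the bucket for hash code 0
        key.id = 222;        // lookups now search the bucket for hash code 222
        return map;
    }

    public static void main(String[] args) {
        MutableKey key = new MutableKey();
        Map<MutableKey, Integer> map = buildStaleMap(key);
        System.out.println(map.size());
        System.out.println(map.containsKey(key));
        System.out.println(map.keySet().iterator().next() == key);
    }
}
```

The map still reports one entry and iteration still yields the key, yet containsKey and get cannot find it.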
Now, on a more abstract level, this was already clear from the observed behavior. Therefore, the original question was not "What is going on here???", but "if my proposed workaround looks ok, or if there is something better I could do."
I think that the workaround could be OK, but would hesitate to say that nothing could go wrong there. It's complicated.
As for the second part: Something better could be to file a bug report at the Java Bug Database, because the new behavior is clearly broken. It may be hard to point out a specification that is violated, but the deserialized map is certainly inconsistent, and this is not acceptable.
(Yes, I could also file a bug report, but think that more research might be necessary in order to make sure it is written properly, not a duplicate, etc....)
Answered By - Marco13
Answer Checked By - David Marino (JavaFixing Volunteer)